Mirror of https://github.com/AIGC-Audio/AudioGPT.git (synced 2025-12-16 20:07:58 +01:00)

Commit: merge tts and t2s into NeuralSeq
@@ -110,11 +110,6 @@ class BasePreprocessor:
             f.writelines([f'{l}\n' for l in mfa_dict])
         with open(f"{processed_dir}/{self.meta_csv_filename}.json", 'w') as f:
             f.write(re.sub(r'\n\s+([\d+\]])', r'\1', json.dumps(items, ensure_ascii=False, sort_keys=False, indent=1)))
-
-        # save to csv
-        meta_df = pd.DataFrame(items)
-        meta_df.to_csv(f"{processed_dir}/metadata_phone.csv")
-
         remove_file(wav_processed_tmp_dir)
@@ -6,9 +6,8 @@ from g2p_en.expand import normalize_numbers
 from nltk import pos_tag
 from nltk.tokenize import TweetTokenizer
 
-from text_to_speech.data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor, register_txt_processors
-from text_to_speech.utils.text.text_encoder import PUNCS, is_sil_phoneme
+from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor, register_txt_processors
+from data_gen.tts.data_gen_utils import is_sil_phoneme, PUNCS
 
 
 class EnG2p(G2p):
     word_tokenize = TweetTokenizer().tokenize
@@ -75,4 +74,4 @@ class TxtProcessor(BaseTxtProcessor):
             else:
                 txt_struct[i_word][1].append(p)
         txt_struct = cls.postprocess(txt_struct, preprocess_args)
         return txt_struct, txt
@@ -1,4 +1,5 @@
 import re
+import jieba
 from pypinyin import pinyin, Style
 from data_gen.tts.data_gen_utils import PUNCS
 from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
@@ -20,6 +21,7 @@ class TxtProcessor(BaseTxtProcessor):
         text = re.sub(f"([{PUNCS}])+", r"\1", text)  # !! -> !
         text = re.sub(f"([{PUNCS}])", r" \1 ", text)
         text = re.sub(rf"\s+", r"", text)
+        text = re.sub(rf"[A-Za-z]+", r"$", text)
         return text
 
     @classmethod
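Note: the added line replaces any run of Latin letters with a "$" placeholder after punctuation collapsing and whitespace removal. A minimal standalone sketch of the effect, assuming only the standard library (the PUNCS value below is a stand-in for the project's constant in data_gen.tts.data_gen_utils):

import re

PUNCS = '!,.?;:'  # stand-in for the repo's PUNCS constant

def preprocess_text(text):
    text = re.sub(f"([{PUNCS}])+", r"\1", text)   # collapse repeated punctuation: "!!" -> "!"
    text = re.sub(f"([{PUNCS}])", r" \1 ", text)  # pad punctuation with spaces
    text = re.sub(r"\s+", "", text)               # drop all whitespace again
    text = re.sub(r"[A-Za-z]+", "$", text)        # new step: mask Latin-letter runs
    return text

print(preprocess_text("我爱 Python!!"))  # -> 我爱$!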
@@ -2,11 +2,10 @@ import os
 import subprocess
 import librosa
 import numpy as np
-from text_to_speech.data_gen.tts.wav_processors.base_processor import BaseWavProcessor, register_wav_processors
-from text_to_speech.utils.audio import trim_long_silences
-from text_to_speech.utils.audio.io import save_wav
-from text_to_speech.utils.audio.rnnoise import rnnoise
-from text_to_speech.utils.commons.hparams import hparams
+from data_gen.tts.wav_processors.base_processor import BaseWavProcessor, register_wav_processors
+from data_gen.tts.data_gen_utils import trim_long_silences
+from utils.audio import save_wav, rnnoise
+from utils.hparams import hparams
 
 
 @register_wav_processors(name='sox_to_wav')
@@ -1,6 +1,4 @@
 import torch
-# from inference.tts.fs import FastSpeechInfer
-# from modules.tts.fs2_orig import FastSpeech2Orig
 from inference.svs.base_svs_infer import BaseSVSInfer
 from utils import load_ckpt
 from utils.hparams import hparams
@@ -1,66 +1,18 @@
-from data_gen.tts.data_gen_utils import is_sil_phoneme
-from resemblyzer import VoiceEncoder
-from data_gen.tts.data_gen_utils import build_phone_encoder, build_word_encoder
-from tasks.tts.dataset_utils import FastSpeechWordDataset
-from tasks.tts.tts_utils import load_data_preprocessor
-from vocoders.hifigan import HifiGanGenerator
+import torch
+from inference.tts.base_tts_infer import BaseTTSInfer
+from utils.ckpt_utils import load_ckpt, get_last_checkpoint
+from modules.GenerSpeech.model.generspeech import GenerSpeech
+import os
 from data_gen.tts.emotion import inference as EmotionEncoder
 from data_gen.tts.emotion.inference import embed_utterance as Embed_utterance
 from data_gen.tts.emotion.inference import preprocess_wav
-import importlib
-import os
-import librosa
-import soundfile as sf
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-from string import punctuation
-import torch
-from utils import audio
-from utils.ckpt_utils import load_ckpt
-from utils.hparams import set_hparams
-from utils.hparams import hparams as hp
 
-class BaseTTSInfer:
-    def __init__(self, hparams, device=None):
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.hparams = hparams
-        self.device = device
-        self.data_dir = hparams['binary_data_dir']
-        self.preprocessor, self.preprocess_args = load_data_preprocessor()
-        self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir)
-        self.ds_cls = FastSpeechWordDataset
-        self.model = self.build_model()
-        self.model.eval()
-        self.model.to(self.device)
-        self.vocoder = self.build_vocoder()
-        self.vocoder.eval()
-        self.vocoder.to(self.device)
-        self.asr_processor, self.asr_model = self.build_asr()
 
+class GenerSpeechInfer(BaseTTSInfer):
     def build_model(self):
-        raise NotImplementedError
-
-    def forward_model(self, inp):
-        raise NotImplementedError
+        model = GenerSpeech(self.ph_encoder)
+        model.eval()
+        load_ckpt(model, self.hparams['work_dir'], 'model')
+        return model
 
-    def build_asr(self):
-        # load pretrained model
-        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # facebook/wav2vec2-base-960h wav2vec2-large-960h-lv60-self
-        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(self.device)
-        return processor, model
-
-    def build_vocoder(self):
-        base_dir = self.hparams['vocoder_ckpt']
-        config_path = f'{base_dir}/config.yaml'
-        config = set_hparams(config_path, global_hparams=False)
-        vocoder = HifiGanGenerator(config)
-        load_ckpt(vocoder, base_dir, 'model_gen')
-        return vocoder
-
-    def run_vocoder(self, c):
-        c = c.transpose(2, 1)
-        y = self.vocoder(c)[:, 0]
-        return y
-
     def preprocess_input(self, inp):
         """
@@ -146,42 +98,23 @@ class BaseTTSInfer:
         }
         return batch
 
-    def postprocess_output(self, output):
-        return output
-
-    def infer_once(self, inp):
-        inp = self.preprocess_input(inp)
-        output = self.forward_model(inp)
-        output = self.postprocess_output(output)
-        return output
-
-    @classmethod
-    def example_run(cls, inp):
-        from utils.audio import save_wav
-
-        #set_hparams(print_hparams=False)
-        infer_ins = cls(hp)
-        out = infer_ins.infer_once(inp)
-        os.makedirs('infer_out', exist_ok=True)
-        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
-        print(f'Save at infer_out/{hp["text"]}.wav.')
-
-    def asr(self, file):
-        sample_rate = self.hparams['audio_sample_rate']
-        audio_input, source_sample_rate = sf.read(file)
-
-        # Resample the wav if needed
-        if sample_rate is not None and source_sample_rate != sample_rate:
-            audio_input = librosa.resample(audio_input, source_sample_rate, sample_rate)
-
-        # pad input values and return pt tensor
-        input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
-
-        # retrieve logits & take argmax
-        logits = self.asr_model(input_values.cuda()).logits
-        predicted_ids = torch.argmax(logits, dim=-1)
-
-        # transcribe
-        transcription = self.asr_processor.decode(predicted_ids[0])
-        transcription = transcription.rstrip(punctuation)
-        return audio_input, transcription
+    def forward_model(self, inp):
+        sample = self.input_to_batch(inp)
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        with torch.no_grad():
+            output = self.model(txt_tokens, ref_mel2ph=sample['mel2ph'], ref_mel2word=sample['mel2word'], ref_mels=sample['mels'],
+                                spk_embed=sample['spk_embed'], emo_embed=sample['emo_embed'], global_steps=300000, infer=True)
+            mel_out = output['mel_out']
+            wav_out = self.run_vocoder(mel_out)
+        wav_out = wav_out.squeeze().cpu().numpy()
+        return wav_out
+
+
+if __name__ == '__main__':
+    inp = {
+        'text': 'here we go',
+        'ref_audio': 'assets/0011_001570.wav'
+    }
+    GenerSpeechInfer.example_run(inp)
NeuralSeq/inference/tts/base_tts_infer.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+from data_gen.tts.data_gen_utils import is_sil_phoneme
+from resemblyzer import VoiceEncoder
+from data_gen.tts.data_gen_utils import build_phone_encoder, build_word_encoder
+from tasks.tts.dataset_utils import FastSpeechWordDataset
+from tasks.tts.tts_utils import load_data_preprocessor
+from vocoders.hifigan import HifiGanGenerator
+import os
+import librosa
+import soundfile as sf
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+from string import punctuation
+import torch
+from utils.ckpt_utils import load_ckpt
+from utils.hparams import set_hparams
+from utils.hparams import hparams as hp
+
+
+class BaseTTSInfer:
+    def __init__(self, hparams, device=None):
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.hparams = hparams
+        self.device = device
+        self.data_dir = hparams['binary_data_dir']
+        self.preprocessor, self.preprocess_args = load_data_preprocessor()
+        self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir)
+        self.ds_cls = FastSpeechWordDataset
+        self.model = self.build_model()
+        self.model.eval()
+        self.model.to(self.device)
+        self.vocoder = self.build_vocoder()
+        self.vocoder.eval()
+        self.vocoder.to(self.device)
+        self.asr_processor, self.asr_model = self.build_asr()
+
+    def build_model(self):
+        raise NotImplementedError
+
+    def forward_model(self, inp):
+        raise NotImplementedError
+
+    def build_asr(self):
+        # load pretrained model
+        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # facebook/wav2vec2-base-960h wav2vec2-large-960h-lv60-self
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(self.device)
+        return processor, model
+
+    def build_vocoder(self):
+        base_dir = self.hparams['vocoder_ckpt']
+        config_path = f'{base_dir}/config.yaml'
+        config = set_hparams(config_path, global_hparams=False)
+        vocoder = HifiGanGenerator(config)
+        load_ckpt(vocoder, base_dir, 'model_gen')
+        return vocoder
+
+    def run_vocoder(self, c):
+        c = c.transpose(2, 1)
+        y = self.vocoder(c)[:, 0]
+        return y
+
+    def preprocess_input(self, inp):
+        raise NotImplementedError
+
+    def input_to_batch(self, item):
+        raise NotImplementedError
+
+    def postprocess_output(self, output):
+        return output
+
+    def infer_once(self, inp):
+        inp = self.preprocess_input(inp)
+        output = self.forward_model(inp)
+        output = self.postprocess_output(output)
+        return output
+
+    @classmethod
+    def example_run(cls, inp):
+        from utils.audio import save_wav
+
+        #set_hparams(print_hparams=False)
+        infer_ins = cls(hp)
+        out = infer_ins.infer_once(inp)
+        os.makedirs('infer_out', exist_ok=True)
+        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
+        print(f'Save at infer_out/{hp["text"]}.wav.')
+
+    def asr(self, file):
+        sample_rate = self.hparams['audio_sample_rate']
+        audio_input, source_sample_rate = sf.read(file)
+
+        # Resample the wav if needed
+        if sample_rate is not None and source_sample_rate != sample_rate:
+            audio_input = librosa.resample(audio_input, source_sample_rate, sample_rate)
+
+        # pad input values and return pt tensor
+        input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
+
+        # retrieve logits & take argmax
+        logits = self.asr_model(input_values.cuda()).logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+
+        # transcribe
+        transcription = self.asr_processor.decode(predicted_ids[0])
+        transcription = transcription.rstrip(punctuation)
+        return audio_input, transcription
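Note: the relocated BaseTTSInfer is an abstract harness; GenerSpeechInfer above fills in build_model and forward_model. A minimal hypothetical subclass would look roughly like the sketch below. MyAcousticModel, the batch keys, and the elided bodies are placeholders, not repo code; only the BaseTTSInfer hooks and run_vocoder come from the file above.

from inference.tts.base_tts_infer import BaseTTSInfer
import torch

class MyTTSInfer(BaseTTSInfer):
    def build_model(self):
        model = MyAcousticModel(self.ph_encoder)   # placeholder: any nn.Module producing mel frames
        model.eval()
        return model

    def preprocess_input(self, inp):
        ...  # turn {'text': ...} into an item dict (phoneme ids, speaker info, ...)

    def input_to_batch(self, item):
        ...  # collate the single item into batched tensors on self.device

    def forward_model(self, inp):
        sample = self.input_to_batch(inp)
        with torch.no_grad():
            mel_out = self.model(sample['txt_tokens'])['mel_out']  # placeholder output format
            wav_out = self.run_vocoder(mel_out)                    # inherited HiFi-GAN call
        return wav_out.squeeze().cpu().numpy()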
@@ -76,7 +76,26 @@ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False)
         pass
     return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
 
 
+class LayerNorm_(torch.nn.LayerNorm):
+    """Layer normalization module.
+    :param int nout: output dim size
+    :param int dim: dimension to be normalized
+    """
+
+    def __init__(self, nout, dim=-1, eps=1e-5):
+        """Construct an LayerNorm object."""
+        super(LayerNorm_, self).__init__(nout, eps=eps)
+        self.dim = dim
+
+    def forward(self, x):
+        """Apply layer normalization.
+        :param torch.Tensor x: input tensor
+        :return: layer normalized tensor
+        :rtype torch.Tensor
+        """
+        if self.dim == -1:
+            return super(LayerNorm_, self).forward(x)
+        return super(LayerNorm_, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
     nn.init.xavier_uniform_(m.weight)
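Note: the added LayerNorm_ behaves like torch.nn.LayerNorm when dim=-1 and, for dim=1, normalizes channel-first tensors by transposing around the base implementation. A self-contained check, assuming only torch (the class body is copied from the hunk above so the snippet runs on its own):

import torch

class LayerNorm_(torch.nn.LayerNorm):
    def __init__(self, nout, dim=-1, eps=1e-5):
        super(LayerNorm_, self).__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        if self.dim == -1:
            return super(LayerNorm_, self).forward(x)
        return super(LayerNorm_, self).forward(x.transpose(1, -1)).transpose(1, -1)

x = torch.randn(2, 8, 50)           # (batch, channels, time)
ln = LayerNorm_(8, dim=1)           # normalize over the 8 channels at dim=1
y = ln(x)
print(y.shape)                      # torch.Size([2, 8, 50])
print(y.mean(dim=1).abs().max())    # ~0: each (batch, time) slice is normalized over channels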
@@ -1,11 +1,12 @@
+from utils.hparams import hparams
 from modules.commons.common_layers import *
 from modules.commons.common_layers import Embedding
 from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
     EnergyPredictor, FastspeechEncoder
 from utils.cwt import cwt2f0
-from utils.hparams import hparams
 from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
+import torch.nn as nn
+from modules.commons.rel_transformer import RelTransformerEncoder, BERTRelTransformerEncoder
 
 FS_ENCODERS = {
     'fft': lambda hp, embed_tokens, d: FastspeechEncoder(
         embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
@@ -27,11 +28,14 @@ class FastSpeech2(nn.Module):
         self.dec_layers = hparams['dec_layers']
         self.hidden_size = hparams['hidden_size']
         self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
-        self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
+        if hparams.get("use_bert", False):
+            self.ph_encoder = BERTRelTransformerEncoder(len(self.dictionary), hparams['hidden_size'], hparams['hidden_size'],
+                                                        hparams['ffn_hidden_size'], hparams['num_heads'], hparams['enc_layers'],
+                                                        hparams['enc_ffn_kernel_size'], hparams['dropout'], prenet=hparams['enc_prenet'], pre_ln=hparams['enc_pre_ln'])
+        else:
+            self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
         self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
-        self.out_dims = out_dims
-        if out_dims is None:
-            self.out_dims = hparams['audio_num_mel_bins']
+        self.out_dims = hparams['audio_num_mel_bins'] if out_dims is None else out_dims
         self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
 
         if hparams['use_spk_id']:
@@ -46,44 +50,26 @@ class FastSpeech2(nn.Module):
             self.hidden_size,
             n_chans=predictor_hidden,
             n_layers=hparams['dur_predictor_layers'],
-            dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
+            dropout_rate=hparams['predictor_dropout'],
             kernel_size=hparams['dur_predictor_kernel'])
         self.length_regulator = LengthRegulator()
         if hparams['use_pitch_embed']:
             self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
-            if hparams['pitch_type'] == 'cwt':
-                h = hparams['cwt_hidden_size']
-                cwt_out_dims = 10
-                if hparams['use_uv']:
-                    cwt_out_dims = cwt_out_dims + 1
-                self.cwt_predictor = nn.Sequential(
-                    nn.Linear(self.hidden_size, h),
-                    PitchPredictor(
-                        h,
-                        n_chans=predictor_hidden,
-                        n_layers=hparams['predictor_layers'],
-                        dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
-                        padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
-                self.cwt_stats_layers = nn.Sequential(
-                    nn.Linear(self.hidden_size, h), nn.ReLU(),
-                    nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
-                )
-            else:
-                self.pitch_predictor = PitchPredictor(
-                    self.hidden_size,
-                    n_chans=predictor_hidden,
-                    n_layers=hparams['predictor_layers'],
-                    dropout_rate=hparams['predictor_dropout'],
-                    odim=2 if hparams['pitch_type'] == 'frame' else 1,
-                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
-        if hparams['use_energy_embed']:
+            self.pitch_predictor = PitchPredictor(
+                self.hidden_size,
+                n_chans=predictor_hidden,
+                n_layers=hparams['predictor_layers'],
+                dropout_rate=hparams['predictor_dropout'],
+                odim=2 if hparams['pitch_type'] == 'frame' else 1,
+                kernel_size=hparams['predictor_kernel'])
+        if hparams.get('use_energy_embed', False):
             self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
             self.energy_predictor = EnergyPredictor(
                 self.hidden_size,
                 n_chans=predictor_hidden,
                 n_layers=hparams['predictor_layers'],
                 dropout_rate=hparams['predictor_dropout'], odim=1,
-                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+                kernel_size=hparams['predictor_kernel'])
 
     def build_embedding(self, dictionary, embed_dim):
         num_embeddings = len(dictionary)
@@ -94,7 +80,10 @@ class FastSpeech2(nn.Module):
                 ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
                 spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
         ret = {}
-        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
+        if hparams.get("use_bert", False):
+            encoder_out = self.encoder(txt_tokens, bert_feats=kwargs['bert_feats'], ph2word=kwargs['ph2word'], ret=ret)
+        else:
+            encoder_out = self.encoder(txt_tokens)  # [B, T, C]
         src_nonpadding = (txt_tokens > 0).float()[:, :, None]
 
         # add ref style embed
@@ -137,7 +126,7 @@ class FastSpeech2(nn.Module):
         if hparams['use_pitch_embed']:
             pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
             decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
-        if hparams['use_energy_embed']:
+        if hparams.get('use_energy_embed', False):
             decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
 
         ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
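Note: switching from hparams['use_energy_embed'] to hparams.get('use_energy_embed', False) lets older configs that never define the key fall back to the previous behaviour instead of raising. A minimal illustration (the config dict is made up):

hparams = {'use_pitch_embed': True}   # an older config without the new key

# hparams['use_energy_embed'] would raise KeyError here;
# .get(...) with a default simply skips the energy branch for such configs.
if hparams.get('use_energy_embed', False):
    print('adding energy embedding')
else:
    print('energy embedding disabled')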
@@ -67,7 +67,7 @@ class DurationPredictor(torch.nn.Module):
     the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
     """
 
-    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
+    def __init__(self, idim, odims=1, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
         """Initilize duration predictor module.
         Args:
             idim (int): Input dimension.
@@ -93,14 +93,6 @@ class DurationPredictor(torch.nn.Module):
                 LayerNorm(n_chans, dim=1),
                 torch.nn.Dropout(dropout_rate)
             )]
-        if hparams['dur_loss'] in ['mse', 'huber']:
-            odims = 1
-        elif hparams['dur_loss'] == 'mog':
-            odims = 15
-        elif hparams['dur_loss'] == 'crf':
-            odims = 32
-            from torchcrf import CRF
-            self.crf = CRF(odims, batch_first=True)
         self.linear = torch.nn.Linear(n_chans, odims)
 
     def _forward(self, xs, x_masks=None, is_inference=False):
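Note: with the hparams['dur_loss'] branch removed, the output width of DurationPredictor is chosen by the caller through the new odims argument (and the CRF head is no longer created here). A hedged sketch of what a caller might now do; the mapping and the idim value are illustrative, only the constructor signature comes from the diff above:

from modules.fastspeech.tts_modules import DurationPredictor

# Widths mirroring the deleted hparams['dur_loss'] branch; which one a task
# actually uses is now decided by the calling code, not by this module.
ODIMS_BY_DUR_LOSS = {'mse': 1, 'huber': 1, 'mog': 15, 'crf': 32}

dur_loss = 'mse'
predictor = DurationPredictor(
    idim=256,                              # encoder hidden size (illustrative)
    odims=ODIMS_BY_DUR_LOSS[dur_loss],     # previously hard-wired inside __init__
    n_layers=2, n_chans=384, kernel_size=3,
    dropout_rate=0.1,
)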
@@ -150,6 +142,39 @@ class DurationPredictor(torch.nn.Module):
         """
         return self._forward(xs, x_masks, True)
 
 
+class SyntaDurationPredictor(torch.nn.Module):
+    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
+        super(SyntaDurationPredictor, self).__init__()
+        from modules.syntaspeech.syntactic_graph_encoder import GraphAuxEnc
+        self.graph_encoder = GraphAuxEnc(in_dim=idim, hid_dim=idim, out_dim=idim)
+        self.offset = offset
+        self.conv = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        for idx in range(n_layers):
+            in_chans = idim if idx == 0 else n_chans
+            self.conv += [torch.nn.Sequential(
+                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
+                torch.nn.ReLU(),
+                LayerNorm(n_chans, dim=1),
+                torch.nn.Dropout(dropout_rate)
+            )]
+        self.linear = nn.Sequential(torch.nn.Linear(n_chans, 1), nn.Softplus())
+
+    def forward(self, x, x_padding=None, ph2word=None, graph_lst=None, etypes_lst=None):
+        x = x.transpose(1, -1)  # (B, idim, Tmax)
+        assert ph2word is not None and graph_lst is not None and etypes_lst is not None
+        x_graph = self.graph_encoder(graph_lst, x, ph2word, etypes_lst)
+        x = x + x_graph * 1.
+
+        for f in self.conv:
+            x = f(x)  # (B, C, Tmax)
+            if x_padding is not None:
+                x = x * (1 - x_padding.float())[:, None, :]
+
+        x = self.linear(x.transpose(1, -1))  # [B, T, C]
+        x = x * (1 - x_padding.float())[:, :, None]  # (B, T, C)
+        x = x[..., 0]  # (B, Tmax)
+        return x
+
+
 class LengthRegulator(torch.nn.Module):
     def __init__(self, pad_value=0.0):
Some files were not shown because too many files have changed in this diff.