Mirror of https://github.com/AIGC-Audio/AudioGPT.git (synced 2025-12-16 20:07:58 +01:00)

Commit: merge tts and t2s into NeuralSeq
@@ -110,11 +110,6 @@ class BasePreprocessor:
             f.writelines([f'{l}\n' for l in mfa_dict])
         with open(f"{processed_dir}/{self.meta_csv_filename}.json", 'w') as f:
             f.write(re.sub(r'\n\s+([\d+\]])', r'\1', json.dumps(items, ensure_ascii=False, sort_keys=False, indent=1)))
-
-        # save to csv
-        meta_df = pd.DataFrame(items)
-        meta_df.to_csv(f"{processed_dir}/metadata_phone.csv")
-
         remove_file(wav_processed_tmp_dir)
@@ -6,9 +6,8 @@ from g2p_en.expand import normalize_numbers
 from nltk import pos_tag
 from nltk.tokenize import TweetTokenizer
 
-from text_to_speech.data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor, register_txt_processors
-from text_to_speech.utils.text.text_encoder import PUNCS, is_sil_phoneme
-
+from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor, register_txt_processors
+from data_gen.tts.data_gen_utils import is_sil_phoneme, PUNCS
 
 class EnG2p(G2p):
     word_tokenize = TweetTokenizer().tokenize
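The EnG2p override above swaps in NLTK's TweetTokenizer, which keeps contractions intact where simpler tokenizers split them apart. A minimal check, runnable standalone (assumes nltk is installed; output is indicative):

from nltk.tokenize import TweetTokenizer

word_tokenize = TweetTokenizer().tokenize
print(word_tokenize("I'm testing TTS, isn't it neat?"))
# ["I'm", 'testing', 'TTS', ',', "isn't", 'it', 'neat', '?']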
@@ -1,4 +1,5 @@
 import re
+import jieba
 from pypinyin import pinyin, Style
 from data_gen.tts.data_gen_utils import PUNCS
 from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
@@ -20,6 +21,7 @@ class TxtProcessor(BaseTxtProcessor):
         text = re.sub(f"([{PUNCS}])+", r"\1", text)  # !! -> !
         text = re.sub(f"([{PUNCS}])", r" \1 ", text)
         text = re.sub(rf"\s+", r"", text)
+        text = re.sub(rf"[A-Za-z]+", r"$", text)
         return text
 
     @classmethod
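A quick trace of the normalization chain above, runnable standalone (this PUNCS literal is a stand-in; the real constant comes from data_gen.tts.data_gen_utils):

import re

PUNCS = '!,.?;:'
text = '你好!!hello世界, 再见...'
text = re.sub(f"([{PUNCS}])+", r"\1", text)   # collapse runs: !! -> !
text = re.sub(f"([{PUNCS}])", r" \1 ", text)  # pad punctuation with spaces
text = re.sub(r"\s+", "", text)               # then strip all whitespace
text = re.sub(r"[A-Za-z]+", "$", text)        # the new rule: mask Latin words
print(text)  # 你好!$世界,再见.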
@@ -2,11 +2,10 @@ import os
 import subprocess
 import librosa
 import numpy as np
-from text_to_speech.data_gen.tts.wav_processors.base_processor import BaseWavProcessor, register_wav_processors
-from text_to_speech.utils.audio import trim_long_silences
-from text_to_speech.utils.audio.io import save_wav
-from text_to_speech.utils.audio.rnnoise import rnnoise
-from text_to_speech.utils.commons.hparams import hparams
+from data_gen.tts.wav_processors.base_processor import BaseWavProcessor, register_wav_processors
+from data_gen.tts.data_gen_utils import trim_long_silences
+from utils.audio import save_wav, rnnoise
+from utils.hparams import hparams
 
 
 @register_wav_processors(name='sox_to_wav')
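register_wav_processors follows a standard decorator-registry pattern: processors register themselves under a name, and the preprocessing config selects them by that name. A self-contained sketch (names mirror the diff; the internals and the process signature are illustrative):

REGISTERED_WAV_PROCESSORS = {}

def register_wav_processors(name):
    def _register(cls):
        REGISTERED_WAV_PROCESSORS[name] = cls  # look up processor classes by config name
        return cls
    return _register

@register_wav_processors(name='sox_to_wav')
class SoxToWav:
    def process(self, input_fn):  # signature illustrative
        print(f'sox convert: {input_fn}')

REGISTERED_WAV_PROCESSORS['sox_to_wav']().process('a.flac')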
@@ -1,6 +1,4 @@
 import torch
-# from inference.tts.fs import FastSpeechInfer
-# from modules.tts.fs2_orig import FastSpeech2Orig
 from inference.svs.base_svs_infer import BaseSVSInfer
 from utils import load_ckpt
 from utils.hparams import hparams
@@ -1,66 +1,18 @@
-from data_gen.tts.data_gen_utils import is_sil_phoneme
-from resemblyzer import VoiceEncoder
-from data_gen.tts.data_gen_utils import build_phone_encoder, build_word_encoder
-from tasks.tts.dataset_utils import FastSpeechWordDataset
-from tasks.tts.tts_utils import load_data_preprocessor
-from vocoders.hifigan import HifiGanGenerator
+import torch
+from inference.tts.base_tts_infer import BaseTTSInfer
+from utils.ckpt_utils import load_ckpt, get_last_checkpoint
+from modules.GenerSpeech.model.generspeech import GenerSpeech
+import os
+from data_gen.tts.emotion import inference as EmotionEncoder
+from data_gen.tts.emotion.inference import embed_utterance as Embed_utterance
+from data_gen.tts.emotion.inference import preprocess_wav
+import importlib
-import os
-import librosa
-import soundfile as sf
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-from string import punctuation
-import torch
-from utils import audio
-from utils.ckpt_utils import load_ckpt
-from utils.hparams import set_hparams
-from utils.hparams import hparams as hp
-
-
-class BaseTTSInfer:
-    def __init__(self, hparams, device=None):
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.hparams = hparams
-        self.device = device
-        self.data_dir = hparams['binary_data_dir']
-        self.preprocessor, self.preprocess_args = load_data_preprocessor()
-        self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir)
-        self.ds_cls = FastSpeechWordDataset
-        self.model = self.build_model()
-        self.model.eval()
-        self.model.to(self.device)
-        self.vocoder = self.build_vocoder()
-        self.vocoder.eval()
-        self.vocoder.to(self.device)
-        self.asr_processor, self.asr_model = self.build_asr()
 
+class GenerSpeechInfer(BaseTTSInfer):
     def build_model(self):
-        raise NotImplementedError
-
-    def forward_model(self, inp):
-        raise NotImplementedError
-
-    def build_asr(self):
-        # load pretrained model
-        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # facebook/wav2vec2-base-960h wav2vec2-large-960h-lv60-self
-        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(self.device)
-        return processor, model
-
-    def build_vocoder(self):
-        base_dir = self.hparams['vocoder_ckpt']
-        config_path = f'{base_dir}/config.yaml'
-        config = set_hparams(config_path, global_hparams=False)
-        vocoder = HifiGanGenerator(config)
-        load_ckpt(vocoder, base_dir, 'model_gen')
-        return vocoder
-
-    def run_vocoder(self, c):
-        c = c.transpose(2, 1)
-        y = self.vocoder(c)[:, 0]
-        return y
+        model = GenerSpeech(self.ph_encoder)
+        model.eval()
+        load_ckpt(model, self.hparams['work_dir'], 'model')
+        return model
 
     def preprocess_input(self, inp):
         """
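The transpose(2, 1) in run_vocoder exists because the acoustic model emits mel spectrograms as [B, T, C] while HiFi-GAN consumes channel-first [B, C, T]. A shape-only sketch:

import torch

B, T, n_mels = 1, 200, 80
mel_out = torch.randn(B, T, n_mels)  # decoder output, [B, T, C]
c = mel_out.transpose(2, 1)          # -> [B, C, T] for the vocoder
print(c.shape)                       # torch.Size([1, 80, 200])
# y = vocoder(c)[:, 0] would then yield a [B, samples] waveform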
@@ -146,42 +98,23 @@ class BaseTTSInfer:
         }
         return batch
 
-    def postprocess_output(self, output):
-        return output
-
     def forward_model(self, inp):
         sample = self.input_to_batch(inp)
         txt_tokens = sample['txt_tokens']  # [B, T_t]
         with torch.no_grad():
             output = self.model(txt_tokens, ref_mel2ph=sample['mel2ph'], ref_mel2word=sample['mel2word'], ref_mels=sample['mels'],
                                 spk_embed=sample['spk_embed'], emo_embed=sample['emo_embed'], global_steps=300000, infer=True)
             mel_out = output['mel_out']
             wav_out = self.run_vocoder(mel_out)
         wav_out = wav_out.squeeze().cpu().numpy()
         return wav_out
 
-    def infer_once(self, inp):
-        inp = self.preprocess_input(inp)
-        output = self.forward_model(inp)
-        output = self.postprocess_output(output)
-        return output
-
-    @classmethod
-    def example_run(cls, inp):
-        from utils.audio import save_wav
-
-        # set_hparams(print_hparams=False)
-        infer_ins = cls(hp)
-        out = infer_ins.infer_once(inp)
-        os.makedirs('infer_out', exist_ok=True)
-        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
-        print(f'Save at infer_out/{hp["text"]}.wav.')
-
-    def asr(self, file):
-        sample_rate = self.hparams['audio_sample_rate']
-        audio_input, source_sample_rate = sf.read(file)
-
-        # Resample the wav if needed
-        if sample_rate is not None and source_sample_rate != sample_rate:
-            audio_input = librosa.resample(audio_input, source_sample_rate, sample_rate)
-
-        # pad input values and return pt tensor
-        input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
-
-        # retrieve logits & take argmax
-        logits = self.asr_model(input_values.cuda()).logits
-        predicted_ids = torch.argmax(logits, dim=-1)
-
-        # transcribe
-        transcription = self.asr_processor.decode(predicted_ids[0])
-        transcription = transcription.rstrip(punctuation)
-        return audio_input, transcription
+
+if __name__ == '__main__':
+    inp = {
+        'text': 'here we go',
+        'ref_audio': 'assets/0011_001570.wav'
+    }
+    GenerSpeechInfer.example_run(inp)
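infer_once (now inherited from BaseTTSInfer, shown in the new file below) is a fixed preprocess -> forward -> postprocess pipeline; subclasses only fill in the stages. A toy stand-in showing the contract (no real TTS here):

class ToyInfer:
    def preprocess_input(self, inp):
        return {'text': inp['text'].strip().lower()}

    def forward_model(self, inp):
        return f"<mel for: {inp['text']}>"  # stand-in for acoustic model + vocoder

    def postprocess_output(self, output):
        return output

    def infer_once(self, inp):
        return self.postprocess_output(self.forward_model(self.preprocess_input(inp)))

print(ToyInfer().infer_once({'text': ' Here We Go '}))  # <mel for: here we go>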
NeuralSeq/inference/tts/base_tts_infer.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+from data_gen.tts.data_gen_utils import is_sil_phoneme
+from resemblyzer import VoiceEncoder
+from data_gen.tts.data_gen_utils import build_phone_encoder, build_word_encoder
+from tasks.tts.dataset_utils import FastSpeechWordDataset
+from tasks.tts.tts_utils import load_data_preprocessor
+from vocoders.hifigan import HifiGanGenerator
+import os
+import librosa
+import soundfile as sf
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+from string import punctuation
+import torch
+from utils.ckpt_utils import load_ckpt
+from utils.hparams import set_hparams
+from utils.hparams import hparams as hp
+
+
+class BaseTTSInfer:
+    def __init__(self, hparams, device=None):
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.hparams = hparams
+        self.device = device
+        self.data_dir = hparams['binary_data_dir']
+        self.preprocessor, self.preprocess_args = load_data_preprocessor()
+        self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir)
+        self.ds_cls = FastSpeechWordDataset
+        self.model = self.build_model()
+        self.model.eval()
+        self.model.to(self.device)
+        self.vocoder = self.build_vocoder()
+        self.vocoder.eval()
+        self.vocoder.to(self.device)
+        self.asr_processor, self.asr_model = self.build_asr()
+
+    def build_model(self):
+        raise NotImplementedError
+
+    def forward_model(self, inp):
+        raise NotImplementedError
+
+    def build_asr(self):
+        # load pretrained model
+        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # facebook/wav2vec2-base-960h wav2vec2-large-960h-lv60-self
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(self.device)
+        return processor, model
+
+    def build_vocoder(self):
+        base_dir = self.hparams['vocoder_ckpt']
+        config_path = f'{base_dir}/config.yaml'
+        config = set_hparams(config_path, global_hparams=False)
+        vocoder = HifiGanGenerator(config)
+        load_ckpt(vocoder, base_dir, 'model_gen')
+        return vocoder
+
+    def run_vocoder(self, c):
+        c = c.transpose(2, 1)
+        y = self.vocoder(c)[:, 0]
+        return y
+
+    def preprocess_input(self, inp):
+        raise NotImplementedError
+
+    def input_to_batch(self, item):
+        raise NotImplementedError
+
+    def postprocess_output(self, output):
+        return output
+
+    def infer_once(self, inp):
+        inp = self.preprocess_input(inp)
+        output = self.forward_model(inp)
+        output = self.postprocess_output(output)
+        return output
+
+    @classmethod
+    def example_run(cls, inp):
+        from utils.audio import save_wav
+
+        # set_hparams(print_hparams=False)
+        infer_ins = cls(hp)
+        out = infer_ins.infer_once(inp)
+        os.makedirs('infer_out', exist_ok=True)
+        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
+        print(f'Save at infer_out/{hp["text"]}.wav.')
+
+    def asr(self, file):
+        sample_rate = self.hparams['audio_sample_rate']
+        audio_input, source_sample_rate = sf.read(file)
+
+        # Resample the wav if needed
+        if sample_rate is not None and source_sample_rate != sample_rate:
+            audio_input = librosa.resample(audio_input, source_sample_rate, sample_rate)
+
+        # pad input values and return pt tensor
+        input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
+
+        # retrieve logits & take argmax
+        logits = self.asr_model(input_values.cuda()).logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+
+        # transcribe
+        transcription = self.asr_processor.decode(predicted_ids[0])
+        transcription = transcription.rstrip(punctuation)
+        return audio_input, transcription
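The asr() method above is plain wav2vec2 CTC decoding; note it pins the logits computation to .cuda() even though __init__ falls back to CPU. A standalone sketch of the same path that stays on one device (the file path is a placeholder; librosa >= 0.10 requires keyword arguments for resample, unlike the positional call in the diff):

import librosa
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')

wav, sr = sf.read('ref.wav')  # placeholder path
if sr != 16000:
    wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)

inputs = processor(wav, sampling_rate=16000, return_tensors='pt').input_values
with torch.no_grad():
    logits = model(inputs).logits
print(processor.decode(torch.argmax(logits, dim=-1)[0]))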
@@ -76,7 +76,26 @@ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False)
|
||||
pass
|
||||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
class LayerNorm_(torch.nn.LayerNorm):
|
||||
"""Layer normalization module.
|
||||
:param int nout: output dim size
|
||||
:param int dim: dimension to be normalized
|
||||
"""
|
||||
|
||||
def __init__(self, nout, dim=-1, eps=1e-5):
|
||||
"""Construct an LayerNorm object."""
|
||||
super(LayerNorm_, self).__init__(nout, eps=eps)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply layer normalization.
|
||||
:param torch.Tensor x: input tensor
|
||||
:return: layer normalized tensor
|
||||
:rtype torch.Tensor
|
||||
"""
|
||||
if self.dim == -1:
|
||||
return super(LayerNorm_, self).forward(x)
|
||||
return super(LayerNorm_, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
||||
def Linear(in_features, out_features, bias=True):
|
||||
m = nn.Linear(in_features, out_features, bias)
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
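Why the transpose in LayerNorm_: nn.LayerNorm normalizes the last dimension, but the conv stacks in this codebase keep channels at dim 1, so a dim=1 instance lets callers pass [B, C, T] tensors directly. A self-contained check (class body restated from the hunk above):

import torch

class LayerNorm_(torch.nn.LayerNorm):  # as added in the hunk above
    def __init__(self, nout, dim=-1, eps=1e-5):
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        if self.dim == -1:
            return super().forward(x)
        return super().forward(x.transpose(1, -1)).transpose(1, -1)

ln = LayerNorm_(384, dim=1)
x = torch.randn(2, 384, 50)   # [B, C, T], as produced by Conv1d stacks
print(ln(x).shape)            # torch.Size([2, 384, 50]); normalized over C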
@@ -1,11 +1,12 @@
+from utils.hparams import hparams
 from modules.commons.common_layers import *
 from modules.commons.common_layers import Embedding
 from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
     EnergyPredictor, FastspeechEncoder
 from utils.cwt import cwt2f0
-from utils.hparams import hparams
 from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
 
 import torch.nn as nn
+from modules.commons.rel_transformer import RelTransformerEncoder, BERTRelTransformerEncoder
 FS_ENCODERS = {
     'fft': lambda hp, embed_tokens, d: FastspeechEncoder(
         embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
@@ -27,11 +28,14 @@ class FastSpeech2(nn.Module):
         self.dec_layers = hparams['dec_layers']
         self.hidden_size = hparams['hidden_size']
         self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
+        if hparams.get("use_bert", False):
+            self.ph_encoder = BERTRelTransformerEncoder(len(self.dictionary), hparams['hidden_size'], hparams['hidden_size'],
+                                                        hparams['ffn_hidden_size'], hparams['num_heads'], hparams['enc_layers'],
+                                                        hparams['enc_ffn_kernel_size'], hparams['dropout'], prenet=hparams['enc_prenet'], pre_ln=hparams['enc_pre_ln'])
+        else:
+            self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
         self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
-        self.out_dims = out_dims
-        if out_dims is None:
-            self.out_dims = hparams['audio_num_mel_bins']
+        self.out_dims = hparams['audio_num_mel_bins'] if out_dims is None else out_dims
         self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
 
         if hparams['use_spk_id']:
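Several spots in this commit switch hparams['key'] to hparams.get('key', False): older YAML configs predate keys like use_bert and use_energy_embed, and .get with a default keeps them loading. A trivial illustration:

hparams = {'hidden_size': 256}         # an older config without 'use_bert'
print(hparams.get('use_bert', False))  # False, where hparams['use_bert'] raises KeyError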
@@ -46,44 +50,26 @@ class FastSpeech2(nn.Module):
             self.hidden_size,
             n_chans=predictor_hidden,
             n_layers=hparams['dur_predictor_layers'],
-            dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
+            dropout_rate=hparams['predictor_dropout'],
             kernel_size=hparams['dur_predictor_kernel'])
         self.length_regulator = LengthRegulator()
         if hparams['use_pitch_embed']:
             self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
             if hparams['pitch_type'] == 'cwt':
                 h = hparams['cwt_hidden_size']
                 cwt_out_dims = 10
                 if hparams['use_uv']:
                     cwt_out_dims = cwt_out_dims + 1
                 self.cwt_predictor = nn.Sequential(
                     nn.Linear(self.hidden_size, h),
                     PitchPredictor(
                         h,
                         n_chans=predictor_hidden,
                         n_layers=hparams['predictor_layers'],
                         dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
                         padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
                 self.cwt_stats_layers = nn.Sequential(
                     nn.Linear(self.hidden_size, h), nn.ReLU(),
                     nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
                 )
             else:
                 self.pitch_predictor = PitchPredictor(
                     self.hidden_size,
                     n_chans=predictor_hidden,
                     n_layers=hparams['predictor_layers'],
                     dropout_rate=hparams['predictor_dropout'],
                     odim=2 if hparams['pitch_type'] == 'frame' else 1,
-                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
-        if hparams['use_energy_embed']:
+                    kernel_size=hparams['predictor_kernel'])
+        if hparams.get('use_energy_embed', False):
             self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
             self.energy_predictor = EnergyPredictor(
                 self.hidden_size,
                 n_chans=predictor_hidden,
                 n_layers=hparams['predictor_layers'],
                 dropout_rate=hparams['predictor_dropout'], odim=1,
-                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+                kernel_size=hparams['predictor_kernel'])
 
     def build_embedding(self, dictionary, embed_dim):
         num_embeddings = len(dictionary)
@@ -94,6 +80,9 @@ class FastSpeech2(nn.Module):
                 ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
                 spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
         ret = {}
+        if hparams.get("use_bert", False):
+            encoder_out = self.encoder(txt_tokens, bert_feats=kwargs['bert_feats'], ph2word=kwargs['ph2word'], ret=ret)
+        else:
+            encoder_out = self.encoder(txt_tokens)  # [B, T, C]
         src_nonpadding = (txt_tokens > 0).float()[:, :, None]
 
@@ -137,7 +126,7 @@ class FastSpeech2(nn.Module):
         if hparams['use_pitch_embed']:
             pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
             decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
-        if hparams['use_energy_embed']:
+        if hparams.get('use_energy_embed', False):
             decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
 
         ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
@@ -67,7 +67,7 @@ class DurationPredictor(torch.nn.Module):
     the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
     """
 
-    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
+    def __init__(self, idim, odims=1, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
         """Initilize duration predictor module.
         Args:
             idim (int): Input dimension.
@@ -93,14 +93,6 @@ class DurationPredictor(torch.nn.Module):
                 LayerNorm(n_chans, dim=1),
                 torch.nn.Dropout(dropout_rate)
             )]
-        if hparams['dur_loss'] in ['mse', 'huber']:
-            odims = 1
-        elif hparams['dur_loss'] == 'mog':
-            odims = 15
-        elif hparams['dur_loss'] == 'crf':
-            odims = 32
-            from torchcrf import CRF
-            self.crf = CRF(odims, batch_first=True)
         self.linear = torch.nn.Linear(n_chans, odims)
 
     def _forward(self, xs, x_masks=None, is_inference=False):
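With the loss-specific branch deleted above, choosing the output width is now the caller's job via the new odims argument. A hypothetical call site reproducing the old mapping:

odims_by_loss = {'mse': 1, 'huber': 1, 'mog': 15, 'crf': 32}
hparams = {'dur_loss': 'mog'}               # stand-in config
odims = odims_by_loss[hparams['dur_loss']]
print(odims)  # 15; passed as DurationPredictor(idim, odims=odims, ...)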
@@ -150,6 +142,39 @@ class DurationPredictor(torch.nn.Module):
         """
         return self._forward(xs, x_masks, True)
 
+class SyntaDurationPredictor(torch.nn.Module):
+    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
+        super(SyntaDurationPredictor, self).__init__()
+        from modules.syntaspeech.syntactic_graph_encoder import GraphAuxEnc
+        self.graph_encoder = GraphAuxEnc(in_dim=idim, hid_dim=idim, out_dim=idim)
+        self.offset = offset
+        self.conv = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        for idx in range(n_layers):
+            in_chans = idim if idx == 0 else n_chans
+            self.conv += [torch.nn.Sequential(
+                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
+                torch.nn.ReLU(),
+                LayerNorm(n_chans, dim=1),
+                torch.nn.Dropout(dropout_rate)
+            )]
+        self.linear = nn.Sequential(torch.nn.Linear(n_chans, 1), nn.Softplus())
+
+    def forward(self, x, x_padding=None, ph2word=None, graph_lst=None, etypes_lst=None):
+        x = x.transpose(1, -1)  # (B, idim, Tmax)
+        assert ph2word is not None and graph_lst is not None and etypes_lst is not None
+        x_graph = self.graph_encoder(graph_lst, x, ph2word, etypes_lst)
+        x = x + x_graph * 1.
+
+        for f in self.conv:
+            x = f(x)  # (B, C, Tmax)
+            if x_padding is not None:
+                x = x * (1 - x_padding.float())[:, None, :]
+
+        x = self.linear(x.transpose(1, -1))  # [B, T, C]
+        x = x * (1 - x_padding.float())[:, :, None]  # (B, T, C)
+        x = x[..., 0]  # (B, Tmax)
+        return x
+
 class LengthRegulator(torch.nn.Module):
     def __init__(self, pad_value=0.0):
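SyntaDurationPredictor.forward zeroes padded frames with one broadcasting idiom used twice above: a (B, T) pad mask is unsqueezed to (B, 1, T) against channel-first tensors, or (B, T, 1) against channel-last ones. A self-contained check:

import torch

x = torch.randn(2, 4, 5)                            # (B, C, T)
x_padding = torch.tensor([[0, 0, 0, 1, 1],
                          [0, 0, 1, 1, 1]]).bool()  # True marks padding
x = x * (1 - x_padding.float())[:, None, :]         # (B, 1, T) broadcasts over C
print(x[0, :, 3])                                   # tensor([0., 0., 0., 0.])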
Some files were not shown because too many files have changed in this diff.