From 140b56a2e1367e556fd29cc02784f8641c9c63f4 Mon Sep 17 00:00:00 2001 From: liuhaozhe6788 <2792382045@qq.com> Date: Sun, 1 Jan 2023 16:12:48 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=98=E5=88=B6=E6=89=80=E6=9C=89=E6=96=87?= =?UTF-8?q?=E6=9C=AC=E7=89=87=E6=AE=B5=E7=9A=84=E6=B3=A8=E6=84=8F=E5=8A=9B?= =?UTF-8?q?=E5=AF=B9=E9=BD=90=E5=9B=BE=E5=92=8C=E5=90=88=E6=88=90=E7=9A=84?= =?UTF-8?q?=E6=A2=85=E5=B0=94=E9=A2=91=E8=B0=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- demo_cli.py | 9 +++++++-- vocoder/display.py | 28 ++++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/demo_cli.py b/demo_cli.py index 48fddf9..dfd00cd 100644 --- a/demo_cli.py +++ b/demo_cli.py @@ -18,7 +18,7 @@ from synthesizer.inference import Synthesizer from utils.argutils import print_args from utils.default_models import ensure_default_models from vocoder import inference as vocoder -from vocoder.display import save_attention +from vocoder.display import save_attention, save_spectrogram from synthesizer.utils.cleaners import english_cleaners from fixSpeed import * @@ -237,10 +237,15 @@ if __name__ == '__main__': # If you know what the attention layer alignments are, you can retrieve them here by # passing return_alignments=True specs, alignments = synthesizer.synthesize_spectrograms(texts, embeds, return_alignments=True) - save_attention(alignments.detach().cpu().numpy()[-1, :, :], "attention") + breaks = [spec.shape[1] for spec in specs] spec = np.concatenate(specs, axis=1) + + if not os.path.exists("tts_results"): + os.mkdir("tts_results") + save_attention(alignments.detach().cpu().numpy(), "tts_results/attention") + save_spectrogram(spec, "tts_results/mel") print("Created the mel spectrogram") diff --git a/vocoder/display.py b/vocoder/display.py index 3b41609..eb67bf1 100644 --- a/vocoder/display.py +++ b/vocoder/display.py @@ -86,10 +86,27 @@ def time_since(started) : def save_attention(attn, path): import matplotlib.pyplot as plt - fig = plt.figure(figsize=(12, 6)) - plt.imshow(attn.T, interpolation='nearest', aspect='auto') - fig.savefig(f'{path}.png', bbox_inches='tight') - plt.close(fig) + if attn.ndim == 2: + fig = plt.figure(figsize=(12, 6)) + plt.imshow(attn.T, interpolation='nearest', aspect='auto') + plt.xlabel("Decoder Timestep") + plt.ylabel("Encoder Timestep") + plt.title("Encoder-Decoder Alignment") + fig.savefig(f'{path}.png', bbox_inches='tight') + plt.close(fig) + elif attn.ndim == 3: + num_plots = attn.shape[0] + fig = plt.figure(figsize=(12, 6 * num_plots)) + for i, a in enumerate(attn): + plt.subplot(num_plots, 1, i+1) + plt.imshow(a.T, interpolation='nearest', aspect='auto') + plt.xlabel("Decoder Timestep") + plt.ylabel("Encoder Timestep") + plt.title("Encoder-Decoder Alignment") + fig.savefig(f'{path}.png', bbox_inches='tight') + plt.close(fig) + else: + pass def save_spectrogram(M, path, length=None): @@ -99,6 +116,9 @@ def save_spectrogram(M, path, length=None): if length : M = M[:, :length] fig = plt.figure(figsize=(12, 6)) plt.imshow(M, interpolation='nearest', aspect='auto') + plt.xlabel("Time") + plt.ylabel("Frequency") + plt.title("Generated Mel Spectrogram") fig.savefig(f'{path}.png', bbox_inches='tight') plt.close(fig)