绘制所有文本片段的注意力对齐图和合成的梅尔频谱

This commit is contained in:
liuhaozhe6788
2023-01-01 16:12:48 +08:00
parent 36e9ce9369
commit 140b56a2e1
2 changed files with 31 additions and 6 deletions

View File

@@ -18,7 +18,7 @@ from synthesizer.inference import Synthesizer
from utils.argutils import print_args
from utils.default_models import ensure_default_models
from vocoder import inference as vocoder
from vocoder.display import save_attention
from vocoder.display import save_attention, save_spectrogram
from synthesizer.utils.cleaners import english_cleaners
from fixSpeed import *
@@ -237,10 +237,15 @@ if __name__ == '__main__':
# If you know what the attention layer alignments are, you can retrieve them here by
# passing return_alignments=True
specs, alignments = synthesizer.synthesize_spectrograms(texts, embeds, return_alignments=True)
save_attention(alignments.detach().cpu().numpy()[-1, :, :], "attention")
breaks = [spec.shape[1] for spec in specs]
spec = np.concatenate(specs, axis=1)
if not os.path.exists("tts_results"):
os.mkdir("tts_results")
save_attention(alignments.detach().cpu().numpy(), "tts_results/attention")
save_spectrogram(spec, "tts_results/mel")
print("Created the mel spectrogram")

View File

@@ -86,10 +86,27 @@ def time_since(started) :
def save_attention(attn, path):
    """Save encoder-decoder attention alignment(s) as a single PNG image.

    A 2-D input is drawn as one alignment plot. A 3-D input is iterated along
    its first axis and every 2-D slice gets its own subplot, stacked
    vertically in one figure whose height grows by 6 inches per slice. Each
    slice is transposed before plotting; the x axis is labelled
    "Decoder Timestep" and the y axis "Encoder Timestep".

    Args:
        attn: Array exposing ``ndim`` and ``.T`` on its 2-D slices
            (e.g. a NumPy array).
        path: Output path without extension; ``.png`` is appended.

    Returns:
        None. Inputs that are neither 2-D nor 3-D are ignored (kept from the
        original best-effort behaviour), as are 3-D inputs with no slices;
        in both cases no file is written.
    """
    # Normalise both supported ranks to a list of 2-D panels so that a single
    # plotting path serves them, instead of two near-identical branches.
    if attn.ndim == 2:
        panels = [attn]
    elif attn.ndim == 3:
        panels = list(attn)
    else:
        return
    if not panels:
        # An empty batch would request a zero-height figure; nothing to draw.
        return

    # Imported lazily so merely importing this module does not pull in
    # matplotlib (the original also imported it inside the function).
    import matplotlib.pyplot as plt

    num_plots = len(panels)
    fig, axes = plt.subplots(
        num_plots, 1, figsize=(12, 6 * num_plots), squeeze=False
    )
    try:
        # squeeze=False always yields a 2-D grid of Axes; column 0 holds the
        # vertically stacked subplots.
        for ax, panel in zip(axes[:, 0], panels):
            ax.imshow(panel.T, interpolation='nearest', aspect='auto')
            ax.set_xlabel("Decoder Timestep")
            ax.set_ylabel("Encoder Timestep")
            ax.set_title("Encoder-Decoder Alignment")
        fig.savefig(f'{path}.png', bbox_inches='tight')
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
def save_spectrogram(M, path, length=None):
@@ -99,6 +116,9 @@ def save_spectrogram(M, path, length=None):
if length : M = M[:, :length]
fig = plt.figure(figsize=(12, 6))
plt.imshow(M, interpolation='nearest', aspect='auto')
plt.xlabel("Time")
plt.ylabel("Frequency")
plt.title("Generated Mel Spectrogram")
fig.savefig(f'{path}.png', bbox_inches='tight')
plt.close(fig)