mirror of
https://github.com/liuhaozhe6788/voice-cloning-collab.git
synced 2026-05-18 05:04:51 +02:00
绘制所有文本片段的注意力对齐图和合成的梅尔频谱
This commit is contained in:
@@ -18,7 +18,7 @@ from synthesizer.inference import Synthesizer
|
||||
from utils.argutils import print_args
|
||||
from utils.default_models import ensure_default_models
|
||||
from vocoder import inference as vocoder
|
||||
from vocoder.display import save_attention
|
||||
from vocoder.display import save_attention, save_spectrogram
|
||||
from synthesizer.utils.cleaners import english_cleaners
|
||||
from fixSpeed import *
|
||||
|
||||
@@ -237,10 +237,15 @@ if __name__ == '__main__':
|
||||
# If you know what the attention layer alignments are, you can retrieve them here by
|
||||
# passing return_alignments=True
|
||||
specs, alignments = synthesizer.synthesize_spectrograms(texts, embeds, return_alignments=True)
|
||||
save_attention(alignments.detach().cpu().numpy()[-1, :, :], "attention")
|
||||
|
||||
|
||||
breaks = [spec.shape[1] for spec in specs]
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
|
||||
if not os.path.exists("tts_results"):
|
||||
os.mkdir("tts_results")
|
||||
save_attention(alignments.detach().cpu().numpy(), "tts_results/attention")
|
||||
save_spectrogram(spec, "tts_results/mel")
|
||||
print("Created the mel spectrogram")
|
||||
|
||||
|
||||
|
||||
@@ -86,10 +86,27 @@ def time_since(started) :
|
||||
def save_attention(attn, path):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig = plt.figure(figsize=(12, 6))
|
||||
plt.imshow(attn.T, interpolation='nearest', aspect='auto')
|
||||
fig.savefig(f'{path}.png', bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
if attn.ndim == 2:
|
||||
fig = plt.figure(figsize=(12, 6))
|
||||
plt.imshow(attn.T, interpolation='nearest', aspect='auto')
|
||||
plt.xlabel("Decoder Timestep")
|
||||
plt.ylabel("Encoder Timestep")
|
||||
plt.title("Encoder-Decoder Alignment")
|
||||
fig.savefig(f'{path}.png', bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
elif attn.ndim == 3:
|
||||
num_plots = attn.shape[0]
|
||||
fig = plt.figure(figsize=(12, 6 * num_plots))
|
||||
for i, a in enumerate(attn):
|
||||
plt.subplot(num_plots, 1, i+1)
|
||||
plt.imshow(a.T, interpolation='nearest', aspect='auto')
|
||||
plt.xlabel("Decoder Timestep")
|
||||
plt.ylabel("Encoder Timestep")
|
||||
plt.title("Encoder-Decoder Alignment")
|
||||
fig.savefig(f'{path}.png', bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def save_spectrogram(M, path, length=None):
|
||||
@@ -99,6 +116,9 @@ def save_spectrogram(M, path, length=None):
|
||||
if length : M = M[:, :length]
|
||||
fig = plt.figure(figsize=(12, 6))
|
||||
plt.imshow(M, interpolation='nearest', aspect='auto')
|
||||
plt.xlabel("Time")
|
||||
plt.ylabel("Frequency")
|
||||
plt.title("Generated Mel Spectrogram")
|
||||
fig.savefig(f'{path}.png', bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user