fix wavegrad test

This commit is contained in:
erogol
2020-11-17 14:15:14 +01:00
parent a2a142dc39
commit 79ed5debcd

View File

@@ -1,5 +1,6 @@
import unittest
import numpy as np
import torch
from torch import optim
from TTS.vocoder.models.wavegrad import Wavegrad
@@ -33,7 +34,8 @@ class WavegradTrainTest(unittest.TestCase):
[1, 2, 4, 8]])
model.train()
model.to(device)
model.compute_noise_level(1000, 1e-6, 1e-2)
betas = np.linspace(1e-6, 1e-2, 1000)
model.compute_noise_level(betas)
model_ref.load_state_dict(model.state_dict())
model_ref.to(device)
count = 0