mirror of
https://github.com/coqui-ai/TTS.git
synced 2025-12-25 12:49:29 +01:00
bug fix for stop token prediciton
This commit is contained in:
@@ -476,6 +476,7 @@ class Decoder(nn.Module):
|
||||
new_memory = outputs[-1]
|
||||
self._update_memory_queue(new_memory)
|
||||
output, stop_token, attention = self.decode(inputs, t, None)
|
||||
stop_token = torch.sigmoid(stop_token.data)
|
||||
outputs += [output]
|
||||
attentions += [attention]
|
||||
stop_tokens += [stop_token]
|
||||
@@ -499,12 +500,10 @@ class StopNet(nn.Module):
|
||||
super(StopNet, self).__init__()
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.linear = nn.Linear(in_features, 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.dropout(inputs)
|
||||
outputs = self.linear(outputs)
|
||||
outputs = self.sigmoid(outputs)
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user