mirror of
https://github.com/coqui-ai/TTS.git
synced 2025-12-25 12:49:29 +01:00
add sequence_mask to utils.data
This commit is contained in:
@@ -5,7 +5,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional
|
||||
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.ssim import ssim
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -65,3 +66,12 @@ class StandardScaler:
|
||||
X *= self.scale_
|
||||
X += self.mean_
|
||||
return X
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.data.max()
|
||||
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
|
||||
# B x T_max
|
||||
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
|
||||
|
||||
Reference in New Issue
Block a user