import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence


def sort_pack_padded_sequence(input, lengths):
    # pack_padded_sequence expects sequences sorted by length in descending
    # order; sort here and also return the inverse permutation so callers can
    # restore the original batch order afterwards.
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
    return tmp, inv_ix
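
# A minimal usage sketch (shapes and values are illustrative, not from the
# source): packing sorts the batch, and inv_ix undoes that sort later.
#   feats = torch.randn(3, 5, 8)                  # [N, T, E]
#   lens = torch.tensor([3, 5, 2])
#   packed, inv_ix = sort_pack_padded_sequence(feats, lens)
#   # packed.data holds the sorted, flattened timesteps; inv_ix maps the
#   # sorted order back to the original order of `feats`.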


def pad_unsort_packed_sequence(input, inv_ix):
    # Unpack back to a padded [N, T, ...] tensor, then undo the length sort
    # applied by sort_pack_padded_sequence.
    tmp, _ = pad_packed_sequence(input, batch_first=True)
    tmp = tmp[inv_ix]
    return tmp
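
# Round-trip sketch (illustrative; `rnn` is a hypothetical module, not from
# the source): pack, run an RNN, then restore the original batch order.
#   rnn = nn.GRU(8, 16, batch_first=True)
#   feats, lens = torch.randn(3, 5, 8), torch.tensor([3, 5, 2])
#   packed, inv_ix = sort_pack_padded_sequence(feats, lens)
#   out = pad_unsort_packed_sequence(rnn(packed)[0], inv_ix)   # [N, T, 16]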


def pack_wrapper(module, attn_feats, attn_feat_lens):
    # Run `module` on variable-length sequences: pack, apply, then unpack and
    # restore the original order.
    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
    if isinstance(module, torch.nn.RNNBase):
        # RNNs consume the PackedSequence directly and return (output, state).
        return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
    else:
        # Other modules (e.g. nn.Linear) operate on the flattened `data`
        # tensor; rewrap the result with the original batch_sizes.
        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
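
# Usage sketch (hypothetical module and shapes, for illustration only):
#   gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
#   feats = torch.randn(4, 10, 8)                 # [N, T, E]
#   lens = torch.tensor([10, 9, 4, 2])
#   out = pack_wrapper(gru, feats, lens)          # [N, T, 16], padded steps zeroed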


def generate_length_mask(lens, max_length=None):
    # Build a boolean mask of shape [N, max_length] where mask[i, t] is True
    # iff t < lens[i].
    lens = torch.as_tensor(lens)
    N = lens.size(0)
    if max_length is None:
        max_length = max(lens)
    idxs = torch.arange(max_length).repeat(N).view(N, max_length)
    idxs = idxs.to(lens.device)
    mask = (idxs < lens.view(-1, 1))
    return mask
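
# Example (illustrative values):
#   generate_length_mask(torch.tensor([1, 3]), max_length=3)
#   # tensor([[ True, False, False],
#   #         [ True,  True,  True]])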


def mean_with_lens(features, lens):
    """Mean-pool `features` over the time dimension, ignoring padded steps.

    features: [N, T, ...] (the second dimension represents length)
    lens: [N,]
    """
    lens = torch.as_tensor(lens)
    # Clamp the mask to the actual time dimension in case max(lens) differs
    # from features.size(1).
    mask = generate_length_mask(lens, features.size(1)).to(features.device)  # [N, T]

    # Broadcast the mask over any trailing feature dimensions.
    while mask.ndim < features.ndim:
        mask = mask.unsqueeze(-1)
    feature_mean = features * mask
    feature_mean = feature_mean.sum(1)
    while lens.ndim < feature_mean.ndim:
        lens = lens.unsqueeze(1)
    feature_mean = feature_mean / lens.to(features.device)
    return feature_mean
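
# Example sketch (shapes are illustrative):
#   feats = torch.randn(2, 4, 3)                  # [N, T, E]
#   lens = torch.tensor([4, 2])
#   mean_with_lens(feats, lens).shape             # torch.Size([2, 3])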


def max_with_lens(features, lens):
    """Max-pool `features` over the time dimension, ignoring padded steps.

    features: [N, T, ...] (the second dimension represents length)
    lens: [N,]
    """
    lens = torch.as_tensor(lens)
    mask = generate_length_mask(lens).to(features.device)  # [N, T]

    # Set padded positions to -inf so they can never win the max.
    feature_max = features.clone()
    feature_max[~mask] = float("-inf")
    feature_max, _ = feature_max.max(1)
    return feature_max
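
# Example sketch (illustrative): padded steps cannot dominate the max.
#   feats = torch.tensor([[[1.], [9.]], [[5.], [7.]]])   # [N=2, T=2, E=1]
#   lens = torch.tensor([2, 1])
#   max_with_lens(feats, lens)                    # tensor([[9.], [5.]])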


def repeat_tensor(x, n):
    # Prepend a new batch dimension and tile `x` along it `n` times:
    # shape [...] -> [n, ...].
    return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
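
# Example (illustrative):
#   repeat_tensor(torch.randn(4, 3), 2).shape     # torch.Size([2, 4, 3])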


def init(m, method="kaiming"):
    # Weight-initialization hook, intended to be used with nn.Module.apply.
    if isinstance(m, (nn.Conv2d, nn.Conv1d)):
        if method == "kaiming":
            nn.init.kaiming_uniform_(m.weight)
        elif method == "xavier":
            nn.init.xavier_uniform_(m.weight)
        else:
            raise Exception(f"initialization method {method} not supported")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        if method == "kaiming":
            nn.init.kaiming_uniform_(m.weight)
        elif method == "xavier":
            nn.init.xavier_uniform_(m.weight)
        else:
            raise Exception(f"initialization method {method} not supported")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        if method == "kaiming":
            nn.init.kaiming_uniform_(m.weight)
        elif method == "xavier":
            nn.init.xavier_uniform_(m.weight)
        else:
            raise Exception(f"initialization method {method} not supported")
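
# Usage sketch (hypothetical model, for illustration):
#   model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
#   model.apply(init)                             # kaiming by default
#   model.apply(lambda m: init(m, method="xavier"))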


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Standard sinusoidal positional encoding: sine on even feature
        # indices, cosine on odd indices, with geometrically spaced
        # frequencies.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        # Stored as a frozen parameter rather than a buffer, so it moves with
        # the module across devices but is never updated by the optimizer.
        # self.register_buffer("pe", pe)
        self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        # x: [T, N, E]; the encoding broadcasts over the batch dimension.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
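
# Usage sketch (illustrative shapes):
#   pos_enc = PositionalEncoding(d_model=8, dropout=0.0, max_len=100)
#   x = torch.zeros(10, 2, 8)                     # [T, N, E]
#   pos_enc(x).shape                              # torch.Size([10, 2, 8])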