import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

from data_gen.tts.emotion.params_model import *
from data_gen.tts.emotion.params_data import *
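
# The star imports above supply the hyperparameters referenced below but not
# defined in this file: mel_n_channels, model_hidden_size, model_num_layers
# and model_embedding_size.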

class EmotionEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values). The
        # parameters are created directly on loss_device so they remain leaf
        # tensors registered with the module; calling .to() on an nn.Parameter
        # returns a plain non-leaf copy that never receives gradients.
        self.similarity_weight = nn.Parameter(torch.tensor([10.], device=loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.], device=loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)
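
        # Note: gradients are modified in place, so this method is meant to run
        # after loss.backward() and before optimizer.step(). The similarity
        # parameters only carry a .grad once a loss involving them has been
        # backpropagated; before that, .grad is None and the scaling above
        # would raise a TypeError.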

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of the same duration, as a tensor
        of shape (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden
        # state and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # Take only the hidden state of the last layer.
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it.
        embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

        return embeds

    def inference(self, utterances, hidden_init=None):
        """
        Computes raw utterance representations for a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of the same duration, as a tensor
        of shape (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the final hidden state of the last LSTM layer as a tensor of shape
        (batch_size, hidden_size). Unlike forward(), no linear projection, ReLU or
        L2 normalization is applied.
        """
        # Pass the input through the LSTM layers and keep only the final hidden state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        return hidden[-1]
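

if __name__ == "__main__":
    # Minimal usage sketch, assuming the star-imported params modules define
    # mel_n_channels, model_hidden_size and model_embedding_size (they are
    # referenced but not defined in this file). The batch size of 4 and the
    # 160-frame duration are illustrative values, not anything the module
    # prescribes.
    device = loss_device = torch.device("cpu")
    model = EmotionEncoder(device, loss_device)

    utterances = torch.randn(4, 160, mel_n_channels)
    embeds = model(utterances)         # (4, model_embedding_size), L2-normalized rows
    raw = model.inference(utterances)  # (4, model_hidden_size), raw final hidden state

    # Because forward() returns unit-norm rows, the scaled cosine similarity
    # that similarity_weight / similarity_bias parameterize reduces to
    # similarity_weight * (embeds @ embeds.T) + similarity_bias.
    print(embeds.shape, raw.shape)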