mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-16 03:47:55 +01:00
26 lines
795 B
Python
26 lines
795 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from .text_encoder import Text_Encoder
|
|
from .resunet_film import UNetRes_FiLM
|
|
|
|
class LASSNet(nn.Module):
    """Text-conditioned mask estimator built from a text encoder and a FiLM UNet.

    A caption is tokenized and encoded by ``Text_Encoder``; the first element
    of the encoder output serves as the conditioning vector for
    ``UNetRes_FiLM``. The UNet output is passed through a sigmoid before it
    is returned.
    """

    def __init__(self, device='cuda'):
        """Create the text encoder and the conditioned UNet.

        Args:
            device: Forwarded unchanged to ``Text_Encoder``.
        """
        super().__init__()
        self.text_embedder = Text_Encoder(device)
        self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)

    def forward(self, x, caption):
        """Predict a sigmoid-activated mask for ``x`` conditioned on ``caption``.

        Args:
            x: Input handed to the UNet; the original note gives its layout
                as (Batch, 1, T, 128) -- TODO confirm against the caller.
            caption: Text query passed to the text encoder's ``tokenize``.

        Returns:
            ``torch.sigmoid`` applied to the UNet output.
        """
        token_ids, attention = self.text_embedder.tokenize(caption)

        # Only the first element of the encoder output is used.
        text_cond = self.text_embedder(token_ids, attention)[0]

        # The very same embedding feeds both conditioning arguments of the UNet.
        logits = self.UNet(x, text_cond, text_cond)
        return torch.sigmoid(logits)

    def get_tokenizer(self):
        """Return the tokenizer object held by the text encoder."""
        return self.text_embedder.tokenizer
|