mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-16 11:57:58 +01:00
45 lines
1.8 KiB
Python
45 lines
1.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from transformers import *
|
|
import warnings
|
|
warnings.filterwarnings('ignore')
|
|
# pretrained model name: (model class, model tokenizer, output dimension, token style)
|
|
MODELS = {
|
|
'prajjwal1/bert-mini': (BertModel, BertTokenizer),
|
|
}
|
|
|
|
class Text_Encoder(nn.Module):
|
|
def __init__(self, device):
|
|
super(Text_Encoder, self).__init__()
|
|
self.base_model = 'prajjwal1/bert-mini'
|
|
self.dropout = 0.1
|
|
|
|
self.tokenizer = MODELS[self.base_model][1].from_pretrained(self.base_model)
|
|
|
|
self.bert_layer = MODELS[self.base_model][0].from_pretrained(self.base_model,
|
|
add_pooling_layer=False,
|
|
hidden_dropout_prob=self.dropout,
|
|
attention_probs_dropout_prob=self.dropout,
|
|
output_hidden_states=True)
|
|
|
|
self.linear_layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(inplace=True))
|
|
|
|
self.device = device
|
|
|
|
def tokenize(self, caption):
|
|
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
tokenized = self.tokenizer(caption, add_special_tokens=False, padding=True, return_tensors='pt')
|
|
input_ids = tokenized['input_ids']
|
|
attns_mask = tokenized['attention_mask']
|
|
|
|
input_ids = input_ids.to(self.device)
|
|
attns_mask = attns_mask.to(self.device)
|
|
return input_ids, attns_mask
|
|
|
|
def forward(self, input_ids, attns_mask):
|
|
# input_ids, attns_mask = self.tokenize(caption)
|
|
output = self.bert_layer(input_ids=input_ids, attention_mask=attns_mask)[0]
|
|
cls_embed = output[:, 0, :]
|
|
text_embed = self.linear_layer(cls_embed)
|
|
|
|
return text_embed, output # text_embed: (batch, hidden_size) |