This commit is contained in:
lmzjms
2023-03-28 23:30:18 +08:00
parent f89ee96512
commit 03f41ec7a8
36 changed files with 3968 additions and 3 deletions

BIN
assets/2bf90e35.wav Normal file

Binary file not shown.

BIN
assets/5d67d1b9.wav Normal file

Binary file not shown.

View File

@@ -7,21 +7,45 @@ Output:<br />
Input Example : Generate an audio of a piano playing<br />
Output:<br />
![](t2a.png)<br />
Audio:<br />
<audio src="b973e878.wav" controls></audio><br />
## Text-To-Speech
Input Example : Generate a speech with text "here we go"<br />
Output:<br />
![](tts.png)<br />
Audio:<br />
<audio src="fd5cf55e.wav" controls></audio><br />
## Text-To-Sing
Input Example : please generate a piece of singing voice. Text sequence is 小酒窝长睫毛AP是你最美的记号. Note sequence is C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4. Note duration sequence is 0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340.<br />
Output:<br />
![](t2s.png)<br />
Audio:<br />
<audio src="2bf90e35.wav" controls></audio><br />
## Image-To-Audio
First upload your image (.png)<br />
Input Example : Generate the audio of this image<br />
Output:<br />
![](i2a-2.png)<br />
Audio:<br />
<audio src="5d67d1b9.wav" controls></audio><br />
## Speech Recognition
First upload your audio (.wav)<br />
Audio Example :<br />
<audio src="Track 4.wav" controls></audio><br />
Input Example : Generate the text of this speech<br />
Output:<br />
![](asr.png)<br />
## Audio-To-Text
First upload your audio (.wav)<br />
Audio Example :<br />
<audio src="a-group-of-sheep-are-baaing.wav" controls></audio><br />
Input Example : Please tell me the text description of this audio.<br />
Output:<br />
![](a2i.png)<br />
## Style Transfer Text-To-Speech
First upload your audio (.wav)<br />
Input Example : Speak using the voice of this audio. The text is "here we go".<br />

BIN
assets/Track 4.wav Normal file

Binary file not shown.

BIN
assets/a-group-of-sheep-are-baaing.wav Normal file

Binary file not shown.

BIN
assets/a2i.png Normal file

Binary file not shown.


BIN
assets/b973e878.wav Normal file

Binary file not shown.

BIN
assets/fd5cf55e.wav Normal file

Binary file not shown.

BIN
assets/tts.png Normal file

Binary file not shown.


View File

@@ -0,0 +1,3 @@
from .base_model import *
from .transformer_model import *

View File

@@ -0,0 +1,500 @@
# -*- coding: utf-8 -*-
from typing import Dict
import torch
import torch.nn as nn
from .utils import mean_with_lens, repeat_tensor
class CaptionModel(nn.Module):
"""
Encoder-decoder captioning model.
"""
pad_idx = 0
start_idx = 1
end_idx = 2
max_length = 20
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.vocab_size = decoder.vocab_size
self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
self.inference_forward_keys = ["sample_method", "max_length", "temp"]
freeze_encoder = kwargs.get("freeze_encoder", False)
if freeze_encoder:
for param in self.encoder.parameters():
param.requires_grad = False
self.check_decoder_compatibility()
def check_decoder_compatibility(self):
compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
assert isinstance(self.decoder, self.compatible_decoders), \
f"{self.decoder.__class__.__name__} is incompatible with " \
f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
@classmethod
def set_index(cls, start_idx, end_idx):
cls.start_idx = start_idx
cls.end_idx = end_idx
def forward(self, input_dict: Dict):
"""
input_dict: {
(required)
mode: train/inference,
spec,
spec_len,
fc,
attn,
attn_len,
[sample_method: greedy],
[temp: 1.0] (in case of no teacher forcing)
(optional, mode=train)
cap,
cap_len,
ss_ratio,
(optional, mode=inference)
sample_method: greedy/beam,
max_length,
temp,
beam_size (optional, sample_method=beam),
n_best (optional, sample_method=beam),
}
"""
# encoder_input_keys = ["spec", "spec_len", "fc", "attn", "attn_len"]
# encoder_input = { key: input_dict[key] for key in encoder_input_keys }
encoder_output_dict = self.encoder(input_dict)
if input_dict["mode"] == "train":
forward_dict = {
"mode": "train", "sample_method": "greedy", "temp": 1.0
}
for key in self.train_forward_keys:
forward_dict[key] = input_dict[key]
forward_dict.update(encoder_output_dict)
output = self.train_forward(forward_dict)
elif input_dict["mode"] == "inference":
forward_dict = {"mode": "inference"}
default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
for key in self.inference_forward_keys:
if key in input_dict:
forward_dict[key] = input_dict[key]
else:
forward_dict[key] = default_args[key]
if forward_dict["sample_method"] == "beam":
forward_dict["beam_size"] = input_dict.get("beam_size", 3)
forward_dict["n_best"] = input_dict.get("n_best", False)
forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
elif forward_dict["sample_method"] == "dbs":
forward_dict["beam_size"] = input_dict.get("beam_size", 6)
forward_dict["group_size"] = input_dict.get("group_size", 3)
forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
forward_dict.update(encoder_output_dict)
output = self.inference_forward(forward_dict)
else:
raise Exception("mode should be either 'train' or 'inference'")
return output
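# Illustrative call pattern (hypothetical encoder/decoder instances; the key
# names follow the input_dict contract documented in forward() above):
#
#   model = SomeCaptionModel(encoder, decoder)
#   output = model({
#       "mode": "inference",
#       "spec": spec, "spec_len": spec_len,
#       "fc": fc, "attn": attn, "attn_len": attn_len,
#       "sample_method": "beam", "beam_size": 3,
#   })
#   output["seq"]  # [N, max_length] decoded token indices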
def prepare_output(self, input_dict):
output = {}
batch_size = input_dict["fc_emb"].size(0)
if input_dict["mode"] == "train":
max_length = input_dict["cap"].size(1) - 1
elif input_dict["mode"] == "inference":
max_length = input_dict["max_length"]
else:
raise Exception("mode should be either 'train' or 'inference'")
device = input_dict["fc_emb"].device
output["seq"] = torch.full((batch_size, max_length), self.end_idx,
dtype=torch.long)
output["logit"] = torch.empty(batch_size, max_length,
self.vocab_size).to(device)
output["sampled_logprob"] = torch.zeros(batch_size, max_length)
output["embed"] = torch.empty(batch_size, max_length,
self.decoder.d_model).to(device)
return output
def train_forward(self, input_dict):
if input_dict["ss_ratio"] != 1: # scheduled sampling training
input_dict["mode"] = "train"
return self.stepwise_forward(input_dict)
output = self.seq_forward(input_dict)
self.train_process(output, input_dict)
return output
def seq_forward(self, input_dict):
raise NotImplementedError
def train_process(self, output, input_dict):
pass
def inference_forward(self, input_dict):
if input_dict["sample_method"] == "beam":
return self.beam_search(input_dict)
elif input_dict["sample_method"] == "dbs":
return self.diverse_beam_search(input_dict)
return self.stepwise_forward(input_dict)
def stepwise_forward(self, input_dict):
"""Step-by-step decoding"""
output = self.prepare_output(input_dict)
max_length = output["seq"].size(1)
# start sampling
for t in range(max_length):
input_dict["t"] = t
self.decode_step(input_dict, output)
if input_dict["mode"] == "inference": # decide whether to stop when sampling
unfinished_t = output["seq"][:, t] != self.end_idx
if t == 0:
unfinished = unfinished_t
else:
unfinished *= unfinished_t
output["seq"][:, t][~unfinished] = self.end_idx
if unfinished.sum() == 0:
break
self.stepwise_process(output)
return output
def decode_step(self, input_dict, output):
"""Decoding operation of timestep t"""
decoder_input = self.prepare_decoder_input(input_dict, output)
# feed to the decoder to get logit
output_t = self.decoder(decoder_input)
logit_t = output_t["logit"]
# assert logit_t.ndim == 3
if logit_t.size(1) == 1:
logit_t = logit_t.squeeze(1)
embed_t = output_t["embed"].squeeze(1)
elif logit_t.size(1) > 1:
logit_t = logit_t[:, -1, :]
embed_t = output_t["embed"][:, -1, :]
else:
raise Exception("no logit output")
# sample the next input word and get the corresponding logit
sampled = self.sample_next_word(logit_t,
method=input_dict["sample_method"],
temp=input_dict["temp"])
output_t.update(sampled)
output_t["t"] = input_dict["t"]
output_t["logit"] = logit_t
output_t["embed"] = embed_t
self.stepwise_process_step(output, output_t)
def prepare_decoder_input(self, input_dict, output):
"""Prepare the inp ut dict for the decoder"""
raise NotImplementedError
def stepwise_process_step(self, output, output_t):
"""Postprocessing (save output values) after each timestep t"""
t = output_t["t"]
output["logit"][:, t, :] = output_t["logit"]
output["seq"][:, t] = output_t["word"]
output["sampled_logprob"][:, t] = output_t["probs"]
output["embed"][:, t, :] = output_t["embed"]
def stepwise_process(self, output):
"""Postprocessing after the whole step-by-step autoregressive decoding"""
pass
def sample_next_word(self, logit, method, temp):
"""Sample the next word, given probs output by the decoder"""
logprob = torch.log_softmax(logit, dim=1)
if method == "greedy":
sampled_logprob, word = torch.max(logprob.detach(), 1)
elif method == "gumbel":
def sample_gumbel(shape, eps=1e-20):
U = torch.rand(shape).to(logprob.device)
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax_sample(logit, temperature):
y = logit + sample_gumbel(logit.size())
return torch.log_softmax(y / temperature, dim=-1)
_logprob = gumbel_softmax_sample(logprob, temp)
_, word = torch.max(_logprob.data, 1)
sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(-1)
else:
logprob = logprob / temp
if method.startswith("top"):
top_num = float(method[3:])
if 0 < top_num < 1: # top-p sampling
probs = torch.softmax(logit, dim=1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
_cumsum = sorted_probs.cumsum(1)
mask = _cumsum < top_num
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
sorted_probs = sorted_probs * mask.to(sorted_probs)
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
logprob.scatter_(1, sorted_indices, sorted_probs.log())
else: # top-k sampling
k = int(top_num)
tmp = torch.empty_like(logprob).fill_(float('-inf'))
topk, indices = torch.topk(logprob, k, dim=1)
tmp = tmp.scatter(1, indices, topk)
logprob = tmp
word = torch.distributions.Categorical(logits=logprob.detach()).sample()
sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
word = word.detach().long()
# sampled_logprob: [N,], word: [N,]
return {"word": word, "probs": sampled_logprob}
def beam_search(self, input_dict):
output = self.prepare_output(input_dict)
max_length = input_dict["max_length"]
beam_size = input_dict["beam_size"]
if input_dict["n_best"]:
n_best_size = input_dict["n_best_size"]
batch_size, max_length = output["seq"].size()
output["seq"] = torch.full((batch_size, n_best_size, max_length),
self.end_idx, dtype=torch.long)
temp = input_dict["temp"]
# instance-by-instance beam search
for i in range(output["seq"].size(0)):
output_i = self.prepare_beamsearch_output(input_dict)
input_dict["sample_idx"] = i
for t in range(max_length):
input_dict["t"] = t
output_t = self.beamsearch_step(input_dict, output_i)
#######################################
# merge with previous beam and select the current max prob beam
#######################################
logit_t = output_t["logit"]
if logit_t.size(1) == 1:
logit_t = logit_t.squeeze(1)
elif logit_t.size(1) > 1:
logit_t = logit_t[:, -1, :]
else:
raise Exception("no logit output")
logprob_t = torch.log_softmax(logit_t, dim=1)
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
if t == 0: # for the first step, all k seq will have the same probs
topk_logprob, topk_words = logprob_t[0].topk(
beam_size, 0, True, True)
else: # unroll and find top logprob, and their unrolled indices
topk_logprob, topk_words = logprob_t.view(-1).topk(
beam_size, 0, True, True)
topk_words = topk_words.cpu()
output_i["topk_logprob"] = topk_logprob
# output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
rounding_mode='trunc')
output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
if t == 0:
output_i["seq"] = output_i["next_word"].unsqueeze(1)
else:
output_i["seq"] = torch.cat([
output_i["seq"][output_i["prev_words_beam"]],
output_i["next_word"].unsqueeze(1)], dim=1)
# add finished beams to results
is_end = output_i["next_word"] == self.end_idx
if t == max_length - 1:
is_end.fill_(1)
for beam_idx in range(beam_size):
if is_end[beam_idx]:
final_beam = {
"seq": output_i["seq"][beam_idx].clone(),
"score": output_i["topk_logprob"][beam_idx].item()
}
final_beam["score"] = final_beam["score"] / (t + 1)
output_i["done_beams"].append(final_beam)
output_i["topk_logprob"][is_end] -= 1000
self.beamsearch_process_step(output_i, output_t)
self.beamsearch_process(output, output_i, input_dict)
return output
def prepare_beamsearch_output(self, input_dict):
beam_size = input_dict["beam_size"]
device = input_dict["fc_emb"].device
output = {
"topk_logprob": torch.zeros(beam_size).to(device),
"seq": None,
"prev_words_beam": None,
"next_word": None,
"done_beams": [],
}
return output
def beamsearch_step(self, input_dict, output_i):
decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
output_t = self.decoder(decoder_input)
output_t["t"] = input_dict["t"]
return output_t
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
raise NotImplementedError
def beamsearch_process_step(self, output_i, output_t):
pass
def beamsearch_process(self, output, output_i, input_dict):
i = input_dict["sample_idx"]
done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
if input_dict["n_best"]:
done_beams = done_beams[:input_dict["n_best_size"]]
for out_idx, done_beam in enumerate(done_beams):
seq = done_beam["seq"]
output["seq"][i][out_idx, :len(seq)] = seq
else:
seq = done_beams[0]["seq"]
output["seq"][i][:len(seq)] = seq
def diverse_beam_search(self, input_dict):
def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
local_time = t - divm
unaug_logprob = logprob.clone()
if divm > 0:
change = torch.zeros(logprob.size(-1))
for prev_choice in range(divm):
prev_decisions = seq_table[prev_choice][..., local_time]
for prev_labels in range(bdash):
change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
change = change.to(logprob.device)
logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
return logprob, unaug_logprob
output = self.prepare_output(input_dict)
group_size = input_dict["group_size"]
batch_size = output["seq"].size(0)
beam_size = input_dict["beam_size"]
bdash = beam_size // group_size
input_dict["bdash"] = bdash
diversity_lambda = input_dict["diversity_lambda"]
device = input_dict["fc_emb"].device
max_length = input_dict["max_length"]
temp = input_dict["temp"]
group_nbest = input_dict["group_nbest"]
batch_size, max_length = output["seq"].size()
if group_nbest:
output["seq"] = torch.full((batch_size, beam_size, max_length),
self.end_idx, dtype=torch.long)
else:
output["seq"] = torch.full((batch_size, group_size, max_length),
self.end_idx, dtype=torch.long)
for i in range(batch_size):
input_dict["sample_idx"] = i
seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
done_beams_table = [[] for _ in range(group_size)]
output_i = {
"prev_words_beam": [None for _ in range(group_size)],
"next_word": [None for _ in range(group_size)],
"state": [None for _ in range(group_size)]
}
for t in range(max_length + group_size - 1):
input_dict["t"] = t
for divm in range(group_size):
input_dict["divm"] = divm
if t >= divm and t <= max_length + divm - 1:
local_time = t - divm
decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
output_t = self.decoder(decoder_input)
output_t["divm"] = divm
logit_t = output_t["logit"]
if logit_t.size(1) == 1:
logit_t = logit_t.squeeze(1)
elif logit_t.size(1) > 1:
logit_t = logit_t[:, -1, :]
else:
raise Exception("no logit output")
logprob_t = torch.log_softmax(logit_t, dim=1)
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
if local_time == 0: # for the first step, all k seq will have the same probs
topk_logprob, topk_words = logprob_t[0].topk(
bdash, 0, True, True)
else: # unroll and find top logprob, and their unrolled indices
topk_logprob, topk_words = logprob_t.view(-1).topk(
bdash, 0, True, True)
topk_words = topk_words.cpu()
logprob_table[divm] = topk_logprob
output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
if local_time > 0:
seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
seq_table[divm] = torch.cat([
seq_table[divm],
output_i["next_word"][divm].unsqueeze(-1)], -1)
is_end = seq_table[divm][:, t-divm] == self.end_idx
assert seq_table[divm].shape[-1] == t - divm + 1
if t == max_length + divm - 1:
is_end.fill_(1)
for beam_idx in range(bdash):
if is_end[beam_idx]:
final_beam = {
"seq": seq_table[divm][beam_idx].clone(),
"score": logprob_table[divm][beam_idx].item()
}
final_beam["score"] = final_beam["score"] / (t - divm + 1)
done_beams_table[divm].append(final_beam)
logprob_table[divm][is_end] -= 1000
self.dbs_process_step(output_i, output_t)
done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
if group_nbest:
done_beams = sum(done_beams_table, [])
else:
done_beams = [group_beam[0] for group_beam in done_beams_table]
for out_idx, done_beam in enumerate(done_beams):
output["seq"][i, out_idx, :len(done_beam["seq"])] = done_beam["seq"]
return output
def prepare_dbs_decoder_input(self, input_dict, output_i):
raise NotImplementedError
def dbs_process_step(self, output_i, output_t):
pass
class CaptionSequenceModel(nn.Module):
def __init__(self, model, seq_output_size):
super().__init__()
self.model = model
if model.decoder.d_model != seq_output_size:
self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
else:
self.output_transform = lambda x: x
def forward(self, input_dict):
output = self.model(input_dict)
if input_dict["mode"] == "train":
lens = input_dict["cap_len"] - 1
# seq_outputs: [N, d_model]
elif input_dict["mode"] == "inference":
if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
return output
seq = output["seq"]
lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
else:
raise Exception("mode should be either 'train' or 'inference'")
seq_output = mean_with_lens(output["embed"], lens)
seq_output = self.output_transform(seq_output)
output["seq_output"] = seq_output
return output
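# A minimal wiring sketch (hypothetical model instances; shapes follow the
# comments above). CaptionSequenceModel pools the decoder embeddings over the
# valid timesteps and projects them to `seq_output_size`:
#
#   cap_model = ...  # any CaptionModel subclass wrapping an encoder/decoder
#   seq_model = CaptionSequenceModel(cap_model, seq_output_size=512)
#   out = seq_model({"mode": "inference", "spec": spec, "spec_len": spec_len,
#                    "fc": fc, "attn": attn, "attn_len": attn_len})
#   out["seq_output"]  # [N, 512] sequence-level embedding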

View File

@@ -0,0 +1,746 @@
# -*- coding: utf-8 -*-
import math
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from .utils import generate_length_mask, init, PositionalEncoding
class BaseDecoder(nn.Module):
"""
Take word/audio embeddings and output the next word probs
Base decoder, cannot be called directly
All decoders should inherit from this class
"""
def __init__(self, emb_dim, vocab_size, fc_emb_dim,
attn_emb_dim, dropout=0.2):
super().__init__()
self.emb_dim = emb_dim
self.vocab_size = vocab_size
self.fc_emb_dim = fc_emb_dim
self.attn_emb_dim = attn_emb_dim
self.word_embedding = nn.Embedding(vocab_size, emb_dim)
self.in_dropout = nn.Dropout(dropout)
def forward(self, x):
raise NotImplementedError
def load_word_embedding(self, weight, freeze=True):
embedding = np.load(weight)
assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
embedding = torch.as_tensor(embedding).float()
self.word_embedding = nn.Embedding.from_pretrained(embedding,
freeze=freeze)
class RnnDecoder(BaseDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs):
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout,)
self.d_model = d_model
self.num_layers = kwargs.get('num_layers', 1)
self.bidirectional = kwargs.get('bidirectional', False)
self.rnn_type = kwargs.get('rnn_type', "GRU")
self.classifier = nn.Linear(
self.d_model * (self.bidirectional + 1), vocab_size)
def forward(self, x):
raise NotImplementedError
def init_hidden(self, bs, device):
num_dire = self.bidirectional + 1
n_layer = self.num_layers
hid_dim = self.d_model
if self.rnn_type == "LSTM":
return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
else:
return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)
class RnnFcDecoder(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs):
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim * 2,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
self.apply(init)
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None)
fc_emb = input_dict["fc_emb"]
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
p_fc_emb = self.fc_proj(fc_emb)
# embed: [N, T, embed_size]
embed = torch.cat((embed, p_fc_emb), dim=-1)
out, state = self.model(embed, state)
# out: [N, T, hs], states: [num_layers * num_dire, N, hs]
logit = self.classifier(out)
output = {
"state": state,
"embed": out,
"logit": logit
}
return output
class Seq2SeqAttention(nn.Module):
def __init__(self, hs_enc, hs_dec, attn_size):
"""
Args:
hs_enc: encoder hidden size
hs_dec: decoder hidden size
attn_size: attention vector size
"""
super(Seq2SeqAttention, self).__init__()
self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
self.v = nn.Parameter(torch.randn(attn_size))
self.apply(init)
def forward(self, h_dec, h_enc, src_lens):
"""
Args:
h_dec: decoder hidden (query), [N, hs_dec]
h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
src_lens: source (encoder memory) lengths, [N, ]
"""
N = h_enc.size(0)
src_max_len = h_enc.size(1)
h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
attn_input = torch.cat((h_dec, h_enc), dim=-1)
attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
score = score.masked_fill(mask == 0, -1e10)
weights = torch.softmax(score, dim=-1) # [N, src_max_len]
ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
return ctx, weights
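# Shape walk-through for this additive (Bahdanau) attention: with N=4,
# src_max_len=120, hs_enc=512, hs_dec=1024 and attn_size=256, `score` and
# `weights` are [4, 120] (weights sum to 1 over the unmasked positions) and
# `ctx` is [4, 512].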
class AttentionProj(nn.Module):
def __init__(self, hs_enc, hs_dec, embed_dim, attn_size):
super().__init__()
self.q_proj = nn.Linear(hs_dec, embed_dim)
self.kv_proj = nn.Linear(hs_enc, embed_dim)
self.h2attn = nn.Linear(embed_dim * 2, attn_size)
self.v = nn.Parameter(torch.randn(attn_size))
self.apply(init)
def init(self, m):
if isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, h_dec, h_enc, src_lens):
"""
Args:
h_dec: decoder hidden (query), [N, hs_dec]
h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
src_lens: source (encoder memory) lengths, [N, ]
"""
h_enc = self.kv_proj(h_enc) # [N, src_max_len, embed_dim]
h_dec = self.q_proj(h_dec) # [N, embed_dim]
N = h_enc.size(0)
src_max_len = h_enc.size(1)
h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
attn_input = torch.cat((h_dec, h_enc), dim=-1)
attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
score = score.masked_fill(mask == 0, -1e10)
weights = torch.softmax(score, dim=-1) # [N, src_max_len]
ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
return ctx, weights
class BahAttnDecoder(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs):
"""
concatenate fc, attn, word to feed to the rnn
"""
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs)
attn_size = kwargs.get("attn_size", self.d_model)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim * 3,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.attn = Seq2SeqAttention(self.attn_emb_dim,
self.d_model * (self.bidirectional + 1) * \
self.num_layers,
attn_size)
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
self.apply(init)
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
fc_emb = input_dict["fc_emb"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
# embed: [N, 1, embed_size]
if state is None:
state = self.init_hidden(word.size(0), fc_emb.device)
if self.rnn_type == "LSTM":
query = state[0].transpose(0, 1).flatten(1)
else:
query = state.transpose(0, 1).flatten(1)
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
p_fc_emb = self.fc_proj(fc_emb)
p_ctx = self.ctx_proj(c)
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
dim=-1)
out, state = self.model(rnn_input, state)
output = {
"state": state,
"embed": out,
"logit": self.classifier(out),
"attn_weight": attn_weight
}
return output
class BahAttnDecoder2(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs):
"""
add fc, attn, word together to feed to the rnn
"""
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs)
attn_size = kwargs.get("attn_size", self.d_model)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.attn = Seq2SeqAttention(self.emb_dim,
self.d_model * (self.bidirectional + 1) * \
self.num_layers,
attn_size)
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
self.attn_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
self.apply(partial(init, method="xavier"))
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
fc_emb = input_dict["fc_emb"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
p_attn_emb = self.attn_proj(attn_emb)
# embed: [N, 1, embed_size]
if state is None:
state = self.init_hidden(word.size(0), fc_emb.device)
if self.rnn_type == "LSTM":
query = state[0].transpose(0, 1).flatten(1)
else:
query = state.transpose(0, 1).flatten(1)
c, attn_weight = self.attn(query, p_attn_emb, attn_emb_len)
p_fc_emb = self.fc_proj(fc_emb)
rnn_input = embed + c.unsqueeze(1) + p_fc_emb.unsqueeze(1)
out, state = self.model(rnn_input, state)
output = {
"state": state,
"embed": out,
"logit": self.classifier(out),
"attn_weight": attn_weight
}
return output
class ConditionalBahAttnDecoder(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs):
"""
concatenate fc, attn, word to feed to the rnn
"""
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs)
attn_size = kwargs.get("attn_size", self.d_model)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim * 3,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.attn = Seq2SeqAttention(self.attn_emb_dim,
self.d_model * (self.bidirectional + 1) * \
self.num_layers,
attn_size)
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
self.condition_embedding = nn.Embedding(2, emb_dim)
self.apply(init)
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
fc_emb = input_dict["fc_emb"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
condition = input_dict["condition"]
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
condition = torch.as_tensor([[1 - c, c] for c in condition]).to(fc_emb.device)
condition_emb = torch.matmul(condition, self.condition_embedding.weight)
# condition_embs: [N, emb_dim]
# embed: [N, 1, embed_size]
if state is None:
state = self.init_hidden(word.size(0), fc_emb.device)
if self.rnn_type == "LSTM":
query = state[0].transpose(0, 1).flatten(1)
else:
query = state.transpose(0, 1).flatten(1)
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
p_ctx = self.ctx_proj(c)
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), condition_emb.unsqueeze(1)),
dim=-1)
out, state = self.model(rnn_input, state)
output = {
"state": state,
"embed": out,
"logit": self.classifier(out),
"attn_weight": attn_weight
}
return output
class StructBahAttnDecoder(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, struct_vocab_size,
attn_emb_dim, dropout, d_model, **kwargs):
"""
concatenate fc, attn, word to feed to the rnn
"""
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs)
attn_size = kwargs.get("attn_size", self.d_model)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim * 3,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.attn = Seq2SeqAttention(self.attn_emb_dim,
self.d_model * (self.bidirectional + 1) * \
self.num_layers,
attn_size)
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
self.struct_embedding = nn.Embedding(struct_vocab_size, emb_dim)
self.apply(init)
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
fc_emb = input_dict["fc_emb"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
structure = input_dict["structure"]
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
struct_emb = self.struct_embedding(structure)
# struct_embs: [N, emb_dim]
# embed: [N, 1, embed_size]
if state is None:
state = self.init_hidden(word.size(0), fc_emb.device)
if self.rnn_type == "LSTM":
query = state[0].transpose(0, 1).flatten(1)
else:
query = state.transpose(0, 1).flatten(1)
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
p_ctx = self.ctx_proj(c)
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), struct_emb.unsqueeze(1)), dim=-1)
out, state = self.model(rnn_input, state)
output = {
"state": state,
"embed": out,
"logit": self.classifier(out),
"attn_weight": attn_weight
}
return output
class StyleBahAttnDecoder(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs):
"""
concatenate fc, attn, word to feed to the rnn
"""
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs)
attn_size = kwargs.get("attn_size", self.d_model)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim * 3,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.attn = Seq2SeqAttention(self.attn_emb_dim,
self.d_model * (self.bidirectional + 1) * \
self.num_layers,
attn_size)
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
self.apply(init)
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
fc_emb = input_dict["fc_emb"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
style = input_dict["style"]
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
# embed: [N, 1, embed_size]
if state is None:
state = self.init_hidden(word.size(0), fc_emb.device)
if self.rnn_type == "LSTM":
query = state[0].transpose(0, 1).flatten(1)
else:
query = state.transpose(0, 1).flatten(1)
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
p_ctx = self.ctx_proj(c)
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), style.unsqueeze(1)),
dim=-1)
out, state = self.model(rnn_input, state)
output = {
"state": state,
"embed": out,
"logit": self.classifier(out),
"attn_weight": attn_weight
}
return output
class BahAttnDecoder3(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs):
"""
concatenate fc, attn, word to feed to the rnn
"""
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs)
attn_size = kwargs.get("attn_size", self.d_model)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim + attn_emb_dim,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.attn = Seq2SeqAttention(self.attn_emb_dim,
self.d_model * (self.bidirectional + 1) * \
self.num_layers,
attn_size)
self.ctx_proj = lambda x: x
self.apply(init)
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
fc_emb = input_dict["fc_emb"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
if word.size(-1) == self.fc_emb_dim: # fc_emb
embed = word.unsqueeze(1)
elif word.size(-1) == 1: # word
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
else:
raise Exception(f"problem with word input size {word.size()}")
# embed: [N, 1, embed_size]
if state is None:
state = self.init_hidden(word.size(0), fc_emb.device)
if self.rnn_type == "LSTM":
query = state[0].transpose(0, 1).flatten(1)
else:
query = state.transpose(0, 1).flatten(1)
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
p_ctx = self.ctx_proj(c)
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1)), dim=-1)
out, state = self.model(rnn_input, state)
output = {
"state": state,
"embed": out,
"logit": self.classifier(out),
"attn_weight": attn_weight
}
return output
class SpecificityBahAttnDecoder(RnnDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs):
"""
concatenate fc, attn, word to feed to the rnn
"""
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, d_model, **kwargs)
attn_size = kwargs.get("attn_size", self.d_model)
self.model = getattr(nn, self.rnn_type)(
input_size=self.emb_dim + attn_emb_dim + 1,
hidden_size=self.d_model,
batch_first=True,
num_layers=self.num_layers,
bidirectional=self.bidirectional)
self.attn = Seq2SeqAttention(self.attn_emb_dim,
self.d_model * (self.bidirectional + 1) * \
self.num_layers,
attn_size)
self.ctx_proj = lambda x: x
self.apply(init)
def forward(self, input_dict):
word = input_dict["word"]
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
fc_emb = input_dict["fc_emb"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
condition = input_dict["condition"] # [N,]
word = word.to(fc_emb.device)
embed = self.in_dropout(self.word_embedding(word))
# embed: [N, 1, embed_size]
if state is None:
state = self.init_hidden(word.size(0), fc_emb.device)
if self.rnn_type == "LSTM":
query = state[0].transpose(0, 1).flatten(1)
else:
query = state.transpose(0, 1).flatten(1)
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
p_ctx = self.ctx_proj(c)
rnn_input = torch.cat(
(embed, p_ctx.unsqueeze(1), condition.reshape(-1, 1, 1)),
dim=-1)
out, state = self.model(rnn_input, state)
output = {
"state": state,
"embed": out,
"logit": self.classifier(out),
"attn_weight": attn_weight
}
return output
class TransformerDecoder(BaseDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, **kwargs):
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout=dropout,)
self.d_model = emb_dim
self.nhead = kwargs.get("nhead", self.d_model // 64)
self.nlayers = kwargs.get("nlayers", 2)
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
self.pos_encoder = PositionalEncoding(self.d_model, dropout)
layer = nn.TransformerDecoderLayer(d_model=self.d_model,
nhead=self.nhead,
dim_feedforward=self.dim_feedforward,
dropout=dropout)
self.model = nn.TransformerDecoder(layer, self.nlayers)
self.classifier = nn.Linear(self.d_model, vocab_size)
self.attn_proj = nn.Sequential(
nn.Linear(self.attn_emb_dim, self.d_model),
nn.ReLU(),
nn.Dropout(dropout),
nn.LayerNorm(self.d_model)
)
# self.attn_proj = lambda x: x
self.init_params()
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def generate_square_subsequent_mask(self, max_length):
mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
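# For example, generate_square_subsequent_mask(3) yields the causal mask
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])
# so position t may only attend to positions <= t.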
def forward(self, input_dict):
word = input_dict["word"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
cap_padding_mask = input_dict["cap_padding_mask"]
p_attn_emb = self.attn_proj(attn_emb)
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
word = word.to(attn_emb.device)
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
embed = embed.transpose(0, 1) # [T, N, emb_dim]
embed = self.pos_encoder(embed)
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
tgt_key_padding_mask=cap_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
output = output.transpose(0, 1)
output = {
"embed": output,
"logit": self.classifier(output),
}
return output
class EventTransformerDecoder(TransformerDecoder):
def forward(self, input_dict):
word = input_dict["word"] # index of word embeddings
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
cap_padding_mask = input_dict["cap_padding_mask"]
event_emb = input_dict["event"] # [N, emb_dim]
p_attn_emb = self.attn_proj(attn_emb)
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
word = word.to(attn_emb.device)
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
embed = embed.transpose(0, 1) # [T, N, emb_dim]
embed += event_emb
embed = self.pos_encoder(embed)
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
tgt_key_padding_mask=cap_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
output = output.transpose(0, 1)
output = {
"embed": output,
"logit": self.classifier(output),
}
return output
class KeywordProbTransformerDecoder(TransformerDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, keyword_classes_num, **kwargs):
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, **kwargs)
self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
self.word_keyword_norm = nn.LayerNorm(self.d_model)
def forward(self, input_dict):
word = input_dict["word"] # index of word embeddings
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
cap_padding_mask = input_dict["cap_padding_mask"]
keyword = input_dict["keyword"] # [N, keyword_classes_num]
p_attn_emb = self.attn_proj(attn_emb)
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
word = word.to(attn_emb.device)
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
embed = embed.transpose(0, 1) # [T, N, emb_dim]
embed += self.keyword_proj(keyword)
embed = self.word_keyword_norm(embed)
embed = self.pos_encoder(embed)
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
tgt_key_padding_mask=cap_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
output = output.transpose(0, 1)
output = {
"embed": output,
"logit": self.classifier(output),
}
return output
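# Single-step decoding sketch for TransformerDecoder (dummy tensors; key names
# match the forward() contract above):
#
#   decoder = TransformerDecoder(emb_dim=256, vocab_size=5000, fc_emb_dim=512,
#                                attn_emb_dim=512, dropout=0.2)
#   out = decoder({
#       "word": torch.full((4, 1), 1, dtype=torch.long),       # <start> tokens
#       "attn_emb": torch.randn(4, 120, 512),
#       "attn_emb_len": torch.full((4,), 120, dtype=torch.long),
#       "cap_padding_mask": torch.zeros(4, 1, dtype=torch.bool),
#   })
#   out["logit"]  # [4, 1, 5000]; out["embed"]: [4, 1, 256]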

View File

@@ -0,0 +1,686 @@
# -*- coding: utf-8 -*-
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms
from torchlibrosa.augmentation import SpecAugmentation
from .utils import mean_with_lens, max_with_lens, \
init, pack_wrapper, generate_length_mask, PositionalEncoding
def init_layer(layer):
"""Initialize a Linear or Convolutional layer. """
nn.init.xavier_uniform_(layer.weight)
if hasattr(layer, 'bias'):
if layer.bias is not None:
layer.bias.data.fill_(0.)
def init_bn(bn):
"""Initialize a Batchnorm layer. """
bn.bias.data.fill_(0.)
bn.weight.data.fill_(1.)
class BaseEncoder(nn.Module):
"""
Encode the given audio into embedding
Base encoder class, cannot be called directly
All encoders should inherit from this class
"""
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
super(BaseEncoder, self).__init__()
self.spec_dim = spec_dim
self.fc_feat_dim = fc_feat_dim
self.attn_feat_dim = attn_feat_dim
def forward(self, x):
#########################
# an encoder first encodes the audio feature into embeddings, returning
# `encoded`: {
#     fc_emb: [N, fc_emb_dim],
#     attn_emb: [N, attn_max_len, attn_emb_dim],
#     attn_emb_len: [N,]
# }
#########################
raise NotImplementedError
class Block2D(nn.Module):
def __init__(self, cin, cout, kernel_size=3, padding=1):
super().__init__()
self.block = nn.Sequential(
nn.BatchNorm2d(cin),
nn.Conv2d(cin,
cout,
kernel_size=kernel_size,
padding=padding,
bias=False),
nn.LeakyReLU(inplace=True, negative_slope=0.1))
def forward(self, x):
return self.block(x)
class LinearSoftPool(nn.Module):
"""LinearSoftPool
Linear softmax: takes logits and returns a probability close to the actual maximum value.
Taken from the paper:
A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling
https://arxiv.org/abs/1810.09050
"""
def __init__(self, pooldim=1):
super().__init__()
self.pooldim = pooldim
def forward(self, logits, time_decision):
return (time_decision**2).sum(self.pooldim) / time_decision.sum(
self.pooldim)
class MeanPool(nn.Module):
def __init__(self, pooldim=1):
super().__init__()
self.pooldim = pooldim
def forward(self, logits, decision):
return torch.mean(decision, dim=self.pooldim)
class AttentionPool(nn.Module):
"""docstring for AttentionPool"""
def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
super().__init__()
self.inputdim = inputdim
self.outputdim = outputdim
self.pooldim = pooldim
self.transform = nn.Linear(inputdim, outputdim)
self.activ = nn.Softmax(dim=self.pooldim)
self.eps = 1e-7
def forward(self, logits, decision):
# Input is (B, T, D)
# B, T, D
w = self.activ(torch.clamp(self.transform(logits), -15, 15))
detect = (decision * w).sum(
self.pooldim) / (w.sum(self.pooldim) + self.eps)
# detect: (B, D)
return detect
class MMPool(nn.Module):
def __init__(self, dims):
super().__init__()
self.avgpool = nn.AvgPool2d(dims)
self.maxpool = nn.MaxPool2d(dims)
def forward(self, x):
return self.avgpool(x) + self.maxpool(x)
def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
"""parse_poolingfunction
A helper function to parse any temporal pooling
Pooling is done on dimension 1
:param poolingfunction_name:
:param **kwargs:
"""
poolingfunction_name = poolingfunction_name.lower()
if poolingfunction_name == 'mean':
return MeanPool(pooldim=1)
elif poolingfunction_name == 'linear':
return LinearSoftPool(pooldim=1)
elif poolingfunction_name == 'attention':
return AttentionPool(inputdim=kwargs['inputdim'],
outputdim=kwargs['outputdim'])
else:
raise ValueError(f"pooling function {poolingfunction_name} is not supported")
def embedding_pooling(x, lens, pooling="mean"):
if pooling == "max":
fc_embs = max_with_lens(x, lens)
elif pooling == "mean":
fc_embs = mean_with_lens(x, lens)
elif pooling == "mean+max":
x_mean = mean_with_lens(x, lens)
x_max = max_with_lens(x, lens)
fc_embs = x_mean + x_max
elif pooling == "last":
indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
# indices: [N, 1, hidden]
fc_embs = torch.gather(x, 1, indices).squeeze(1)
else:
raise Exception(f"pooling method {pooling} not support")
return fc_embs
class Cdur5Encoder(BaseEncoder):
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
self.pooling = pooling
self.features = nn.Sequential(
Block2D(1, 32),
nn.LPPool2d(4, (2, 4)),
Block2D(32, 128),
Block2D(128, 128),
nn.LPPool2d(4, (2, 4)),
Block2D(128, 128),
Block2D(128, 128),
nn.LPPool2d(4, (1, 4)),
nn.Dropout(0.3),
)
with torch.no_grad():
rnn_input_dim = self.features(
torch.randn(1, 1, 500, spec_dim)).shape
rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
self.gru = nn.GRU(rnn_input_dim,
128,
bidirectional=True,
batch_first=True)
self.apply(init)
def forward(self, input_dict):
x = input_dict["spec"]
lens = input_dict["spec_len"]
if "upsample" not in input_dict:
input_dict["upsample"] = False
lens = torch.as_tensor(copy.deepcopy(lens))
N, T, _ = x.shape
x = x.unsqueeze(1)
x = self.features(x)
x = x.transpose(1, 2).contiguous().flatten(-2)
x, _ = self.gru(x)
if input_dict["upsample"]:
x = nn.functional.interpolate(
x.transpose(1, 2),
T,
mode='linear',
align_corners=False).transpose(1, 2)
else:
lens //= 4
attn_emb = x
fc_emb = embedding_pooling(x, lens, self.pooling)
return {
"attn_emb": attn_emb,
"fc_emb": fc_emb,
"attn_emb_len": lens
}
def conv_conv_block(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channel,
out_channel,
kernel_size=3,
bias=False,
padding=1),
nn.BatchNorm2d(out_channel),
nn.ReLU(True),
nn.Conv2d(out_channel,
out_channel,
kernel_size=3,
bias=False,
padding=1),
nn.BatchNorm2d(out_channel),
nn.ReLU(True)
)
class Cdur8Encoder(BaseEncoder):
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
self.pooling = pooling
self.features = nn.Sequential(
conv_conv_block(1, 64),
MMPool((2, 2)),
nn.Dropout(0.2, True),
conv_conv_block(64, 128),
MMPool((2, 2)),
nn.Dropout(0.2, True),
conv_conv_block(128, 256),
MMPool((1, 2)),
nn.Dropout(0.2, True),
conv_conv_block(256, 512),
MMPool((1, 2)),
nn.Dropout(0.2, True),
nn.AdaptiveAvgPool2d((None, 1)),
)
self.init_bn = nn.BatchNorm2d(spec_dim)
self.embedding = nn.Linear(512, 512)
self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True)
self.apply(init)
def forward(self, input_dict):
x = input_dict["spec"]
lens = input_dict["spec_len"]
lens = torch.as_tensor(copy.deepcopy(lens))
x = x.unsqueeze(1) # B x 1 x T x D
x = x.transpose(1, 3)
x = self.init_bn(x)
x = x.transpose(1, 3)
x = self.features(x)
x = x.transpose(1, 2).contiguous().flatten(-2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu_(self.embedding(x))
x, _ = self.gru(x)
attn_emb = x
lens //= 4
fc_emb = embedding_pooling(x, lens, self.pooling)
return {
"attn_emb": attn_emb,
"fc_emb": fc_emb,
"attn_emb_len": lens
}
class Cnn10Encoder(BaseEncoder):
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
self.features = nn.Sequential(
conv_conv_block(1, 64),
nn.AvgPool2d((2, 2)),
nn.Dropout(0.2, True),
conv_conv_block(64, 128),
nn.AvgPool2d((2, 2)),
nn.Dropout(0.2, True),
conv_conv_block(128, 256),
nn.AvgPool2d((2, 2)),
nn.Dropout(0.2, True),
conv_conv_block(256, 512),
nn.AvgPool2d((2, 2)),
nn.Dropout(0.2, True),
nn.AdaptiveAvgPool2d((None, 1)),
)
self.init_bn = nn.BatchNorm2d(spec_dim)
self.embedding = nn.Linear(512, 512)
self.apply(init)
def forward(self, input_dict):
x = input_dict["spec"]
lens = input_dict["spec_len"]
lens = torch.as_tensor(copy.deepcopy(lens))
x = x.unsqueeze(1) # [N, 1, T, D]
x = x.transpose(1, 3)
x = self.init_bn(x)
x = x.transpose(1, 3)
x = self.features(x) # [N, 512, T/16, 1]
x = x.transpose(1, 2).contiguous().flatten(-2) # [N, T/16, 512]
attn_emb = x
lens //= 16
fc_emb = embedding_pooling(x, lens, "mean+max")
fc_emb = F.dropout(fc_emb, p=0.5, training=self.training)
fc_emb = self.embedding(fc_emb)
fc_emb = F.relu_(fc_emb)
return {
"attn_emb": attn_emb,
"fc_emb": fc_emb,
"attn_emb_len": lens
}
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1), bias=False)
self.conv2 = nn.Conv2d(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1), bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.init_weight()
def init_weight(self):
init_layer(self.conv1)
init_layer(self.conv2)
init_bn(self.bn1)
init_bn(self.bn2)
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
x = input
x = F.relu_(self.bn1(self.conv1(x)))
x = F.relu_(self.bn2(self.conv2(x)))
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x1 = F.avg_pool2d(x, kernel_size=pool_size)
x2 = F.max_pool2d(x, kernel_size=pool_size)
x = x1 + x2
else:
raise Exception('Incorrect argument!')
return x
class Cnn14Encoder(nn.Module):
def __init__(self, sample_rate=32000):
super().__init__()
sr_to_fmax = {
32000: 14000,
16000: 8000
}
# Logmel spectrogram extractor
self.melspec_extractor = transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=32 * sample_rate // 1000,
win_length=32 * sample_rate // 1000,
hop_length=10 * sample_rate // 1000,
f_min=50,
f_max=sr_to_fmax[sample_rate],
n_mels=64,
norm="slaney",
mel_scale="slaney"
)
self.hop_length = 10 * sample_rate // 1000
self.db_transform = transforms.AmplitudeToDB()
# Spec augmenter
self.spec_augmenter = SpecAugmentation(time_drop_width=64,
time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2)
self.bn0 = nn.BatchNorm2d(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
self.downsample_ratio = 32
self.fc1 = nn.Linear(2048, 2048, bias=True)
self.init_weight()
def init_weight(self):
init_bn(self.bn0)
init_layer(self.fc1)
def load_pretrained(self, pretrained):
checkpoint = torch.load(pretrained, map_location="cpu")
if "model" in checkpoint:
state_keys = checkpoint["model"].keys()
backbone = False
for key in state_keys:
if key.startswith("backbone."):
backbone = True
break
if backbone: # COLA
state_dict = {}
for key, value in checkpoint["model"].items():
if key.startswith("backbone."):
model_key = key.replace("backbone.", "")
state_dict[model_key] = value
else: # PANNs
state_dict = checkpoint["model"]
elif "state_dict" in checkpoint: # CLAP
state_dict = checkpoint["state_dict"]
state_dict_keys = list(filter(
lambda x: "audio_encoder" in x, state_dict.keys()))
state_dict = {
key.replace('audio_encoder.', ''): state_dict[key]
for key in state_dict_keys
}
else:
raise Exception("Unkown checkpoint format")
model_dict = self.state_dict()
pretrained_dict = {
k: v for k, v in state_dict.items() if (k in model_dict) and (
model_dict[k].shape == v.shape)
}
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict, strict=True)
def forward(self, input_dict):
"""
Input: (batch_size, n_samples)"""
waveform = input_dict["wav"]
wave_length = input_dict["wav_len"]
specaug = input_dict["specaug"]
x = self.melspec_extractor(waveform)
x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
x = x.transpose(1, 2)
x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
# SpecAugment
if self.training and specaug:
x = self.spec_augmenter(x)
x = x.transpose(1, 3)
x = self.bn0(x)
x = x.transpose(1, 3)
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = torch.mean(x, dim=3)
attn_emb = x.transpose(1, 2)
wave_length = torch.as_tensor(wave_length)
feat_length = torch.div(wave_length, self.hop_length,
rounding_mode="floor") + 1
feat_length = torch.div(feat_length, self.downsample_ratio,
rounding_mode="floor")
x_max = max_with_lens(attn_emb, feat_length)
x_mean = mean_with_lens(attn_emb, feat_length)
x = x_max + x_mean
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu_(self.fc1(x))
fc_emb = F.dropout(x, p=0.5, training=self.training)
output_dict = {
'fc_emb': fc_emb,
'attn_emb': attn_emb,
'attn_emb_len': feat_length
}
return output_dict
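# Frame bookkeeping example at sample_rate=32000: hop_length = 320, so a 10 s
# clip (320000 samples) gives 320000 // 320 + 1 = 1001 mel frames, and after
# the downsample_ratio of 32, feat_length = 1001 // 32 = 31 attn_emb steps.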
class RnnEncoder(BaseEncoder):
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim,
pooling="mean", **kwargs):
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
self.pooling = pooling
self.hidden_size = kwargs.get('hidden_size', 512)
self.bidirectional = kwargs.get('bidirectional', False)
self.num_layers = kwargs.get('num_layers', 1)
self.dropout = kwargs.get('dropout', 0.2)
self.rnn_type = kwargs.get('rnn_type', "GRU")
self.in_bn = kwargs.get('in_bn', False)
self.embed_dim = self.hidden_size * (self.bidirectional + 1)
self.network = getattr(nn, self.rnn_type)(
attn_feat_dim,
self.hidden_size,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
dropout=self.dropout,
batch_first=True)
if self.in_bn:
self.bn = nn.BatchNorm1d(self.embed_dim)
self.apply(init)
def forward(self, input_dict):
x = input_dict["attn"]
lens = input_dict["attn_len"]
lens = torch.as_tensor(lens)
# x: [N, T, E]
if self.in_bn:
x = pack_wrapper(self.bn, x, lens)
out = pack_wrapper(self.network, x, lens)
# out: [N, T, hidden]
attn_emb = out
fc_emb = embedding_pooling(out, lens, self.pooling)
return {
"attn_emb": attn_emb,
"fc_emb": fc_emb,
"attn_emb_len": lens
}
class Cnn14RnnEncoder(nn.Module):
def __init__(self, sample_rate=32000, pretrained=None,
freeze_cnn=False, freeze_cnn_bn=False,
pooling="mean", **kwargs):
super().__init__()
self.cnn = Cnn14Encoder(sample_rate)
self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs)
if pretrained is not None:
self.cnn.load_pretrained(pretrained)
if freeze_cnn:
assert pretrained is not None, "cnn is not pretrained but frozen"
for param in self.cnn.parameters():
param.requires_grad = False
self.freeze_cnn_bn = freeze_cnn_bn
def train(self, mode):
super().train(mode=mode)
if self.freeze_cnn_bn:
def bn_eval(module):
class_name = module.__class__.__name__
if class_name.find("BatchNorm") != -1:
module.eval()
self.cnn.apply(bn_eval)
return self
def forward(self, input_dict):
output_dict = self.cnn(input_dict)
output_dict["attn"] = output_dict["attn_emb"]
output_dict["attn_len"] = output_dict["attn_emb_len"]
del output_dict["attn_emb"], output_dict["attn_emb_len"]
output_dict = self.rnn(output_dict)
return output_dict
class TransformerEncoder(BaseEncoder):
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs):
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
self.d_model = d_model
dropout = kwargs.get("dropout", 0.2)
self.nhead = kwargs.get("nhead", self.d_model // 64)
self.nlayers = kwargs.get("nlayers", 2)
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
self.attn_proj = nn.Sequential(
nn.Linear(attn_feat_dim, self.d_model),
nn.ReLU(),
nn.Dropout(dropout),
nn.LayerNorm(self.d_model)
)
layer = nn.TransformerEncoderLayer(d_model=self.d_model,
nhead=self.nhead,
dim_feedforward=self.dim_feedforward,
dropout=dropout)
self.model = nn.TransformerEncoder(layer, self.nlayers)
self.cls_token = nn.Parameter(torch.zeros(d_model))
self.init_params()
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, input_dict):
attn_feat = input_dict["attn"]
attn_feat_len = input_dict["attn_len"]
attn_feat_len = torch.as_tensor(attn_feat_len)
attn_feat = self.attn_proj(attn_feat) # [bs, T, d_model]
cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat(
attn_feat.size(0), 1, 1)
attn_feat = torch.cat((cls_emb, attn_feat), dim=1)
attn_feat = attn_feat.transpose(0, 1)
attn_feat_len += 1
src_key_padding_mask = ~generate_length_mask(
attn_feat_len, attn_feat.size(0)).to(attn_feat.device)
output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask)
attn_emb = output.transpose(0, 1)
fc_emb = attn_emb[:, 0]
return {
"attn_emb": attn_emb,
"fc_emb": fc_emb,
"attn_emb_len": attn_feat_len
}
class Cnn14TransformerEncoder(nn.Module):
def __init__(self, sample_rate=32000, pretrained=None,
freeze_cnn=False, freeze_cnn_bn=False,
d_model="mean", **kwargs):
super().__init__()
self.cnn = Cnn14Encoder(sample_rate)
self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs)
if pretrained is not None:
self.cnn.load_pretrained(pretrained)
if freeze_cnn:
assert pretrained is not None, "cnn is not pretrained but frozen"
for param in self.cnn.parameters():
param.requires_grad = False
self.freeze_cnn_bn = freeze_cnn_bn
def train(self, mode=True):
super().train(mode=mode)
if self.freeze_cnn_bn:
def bn_eval(module):
class_name = module.__class__.__name__
if class_name.find("BatchNorm") != -1:
module.eval()
self.cnn.apply(bn_eval)
return self
def forward(self, input_dict):
output_dict = self.cnn(input_dict)
output_dict["attn"] = output_dict["attn_emb"]
output_dict["attn_len"] = output_dict["attn_emb_len"]
del output_dict["attn_emb"], output_dict["attn_emb_len"]
output_dict = self.trm(output_dict)
return output_dict
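# --- Usage sketch (illustrative, not part of the training pipeline) ---
# Assuming the Cnn14 front end reads raw waveforms from keys like "wav" /
# "wav_len" (as AudioCapModel does elsewhere in this repo), a smoke test of
# the CNN+RNN encoder could look like:
#
#     import torch
#     encoder = Cnn14RnnEncoder(sample_rate=32000, pooling="mean")
#     wav = torch.randn(2, 32000 * 10)                     # two 10 s clips
#     out = encoder({"wav": wav, "wav_len": torch.tensor([320000, 256000])})
#     # out["fc_emb"]: [2, hidden], out["attn_emb"]: [2, T', hidden]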

View File

@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-
import random
import torch
import torch.nn as nn
from .base_model import CaptionModel
from .utils import repeat_tensor
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.models.encoder
class TransformerModel(CaptionModel):
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
if not hasattr(self, "compatible_decoders"):
self.compatible_decoders = (
audio_to_text.captioning.models.decoder.TransformerDecoder,
)
super().__init__(encoder, decoder, **kwargs)
def seq_forward(self, input_dict):
cap = input_dict["cap"]
cap_padding_mask = (cap == self.pad_idx).to(cap.device)
cap_padding_mask = cap_padding_mask[:, :-1]
output = self.decoder(
{
"word": cap[:, :-1],
"attn_emb": input_dict["attn_emb"],
"attn_emb_len": input_dict["attn_emb_len"],
"cap_padding_mask": cap_padding_mask
}
)
return output
def prepare_decoder_input(self, input_dict, output):
decoder_input = {
"attn_emb": input_dict["attn_emb"],
"attn_emb_len": input_dict["attn_emb_len"]
}
t = input_dict["t"]
###############
# determine input word
################
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
word = input_dict["cap"][:, :t+1]
else:
start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
if t == 0:
word = start_word
else:
word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
# word: [N, T]
decoder_input["word"] = word
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
decoder_input["cap_padding_mask"] = cap_padding_mask
return decoder_input
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
decoder_input = {}
t = input_dict["t"]
i = input_dict["sample_idx"]
beam_size = input_dict["beam_size"]
###############
# prepare attn embeds
################
if t == 0:
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
output_i["attn_emb"] = attn_emb
output_i["attn_emb_len"] = attn_emb_len
decoder_input["attn_emb"] = output_i["attn_emb"]
decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
###############
# determine input word
################
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
if t == 0:
word = start_word
else:
word = torch.cat((start_word, output_i["seq"]), dim=-1)
decoder_input["word"] = word
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
decoder_input["cap_padding_mask"] = cap_padding_mask
return decoder_input
class M2TransformerModel(CaptionModel):
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
if not hasattr(self, "compatible_decoders"):
self.compatible_decoders = (
audio_to_text.captioning.models.decoder.M2TransformerDecoder,
)
super().__init__(encoder, decoder, **kwargs)
self.check_encoder_compatibility()
def check_encoder_compatibility(self):
assert isinstance(self.encoder, audio_to_text.captioning.models.encoder.M2TransformerEncoder), \
f"only M2TransformerEncoder is compatible with {self.__class__.__name__}"
def seq_forward(self, input_dict):
cap = input_dict["cap"]
output = self.decoder(
{
"word": cap[:, :-1],
"attn_emb": input_dict["attn_emb"],
"attn_emb_mask": input_dict["attn_emb_mask"],
}
)
return output
def prepare_decoder_input(self, input_dict, output):
decoder_input = {
"attn_emb": input_dict["attn_emb"],
"attn_emb_mask": input_dict["attn_emb_mask"]
}
t = input_dict["t"]
###############
# determine input word
################
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
word = input_dict["cap"][:, :t+1]
else:
start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
if t == 0:
word = start_word
else:
word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
# word: [N, T]
decoder_input["word"] = word
return decoder_input
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
decoder_input = {}
t = input_dict["t"]
i = input_dict["sample_idx"]
beam_size = input_dict["beam_size"]
###############
# prepare attn embeds
################
if t == 0:
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
output_i["attn_emb"] = attn_emb
output_i["attn_emb_mask"] = attn_emb_mask
decoder_input["attn_emb"] = output_i["attn_emb"]
decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
###############
# determine input word
################
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
if t == 0:
word = start_word
else:
word = torch.cat((start_word, output_i["seq"]), dim=-1)
decoder_input["word"] = word
return decoder_input
class EventEncoder(nn.Module):
"""
Encode the Label information in AudioCaps and AudioSet
"""
def __init__(self, emb_dim, vocab_size=527):
super(EventEncoder, self).__init__()
self.label_embedding = nn.Parameter(
torch.randn((vocab_size, emb_dim)), requires_grad=True)
def forward(self, word_idxs):
# word_idxs: [N, vocab_size] multi-hot event labels, normalized into weights
weights = word_idxs / word_idxs.sum(dim=1, keepdim=True)
embeddings = weights @ self.label_embedding
return embeddings
class EventCondTransformerModel(TransformerModel):
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
if not hasattr(self, "compatible_decoders"):
self.compatible_decoders = (
audio_to_text.captioning.models.decoder.EventTransformerDecoder,
)
super().__init__(encoder, decoder, **kwargs)
self.label_encoder = EventEncoder(decoder.emb_dim, 527)
self.train_forward_keys += ["events"]
self.inference_forward_keys += ["events"]
# def seq_forward(self, input_dict):
# cap = input_dict["cap"]
# cap_padding_mask = (cap == self.pad_idx).to(cap.device)
# cap_padding_mask = cap_padding_mask[:, :-1]
# output = self.decoder(
# {
# "word": cap[:, :-1],
# "attn_emb": input_dict["attn_emb"],
# "attn_emb_len": input_dict["attn_emb_len"],
# "cap_padding_mask": cap_padding_mask
# }
# )
# return output
def prepare_decoder_input(self, input_dict, output):
decoder_input = super().prepare_decoder_input(input_dict, output)
decoder_input["events"] = self.label_encoder(input_dict["events"])
return decoder_input
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
t = input_dict["t"]
i = input_dict["sample_idx"]
beam_size = input_dict["beam_size"]
if t == 0:
output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
decoder_input["events"] = output_i["events"]
return decoder_input
class KeywordCondTransformerModel(TransformerModel):
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
if not hasattr(self, "compatible_decoders"):
self.compatible_decoders = (
audio_to_text.captioning.models.decoder.KeywordProbTransformerDecoder,
)
super().__init__(encoder, decoder, **kwargs)
self.train_forward_keys += ["keyword"]
self.inference_forward_keys += ["keyword"]
def seq_forward(self, input_dict):
cap = input_dict["cap"]
cap_padding_mask = (cap == self.pad_idx).to(cap.device)
cap_padding_mask = cap_padding_mask[:, :-1]
keyword = input_dict["keyword"]
output = self.decoder(
{
"word": cap[:, :-1],
"attn_emb": input_dict["attn_emb"],
"attn_emb_len": input_dict["attn_emb_len"],
"keyword": keyword,
"cap_padding_mask": cap_padding_mask
}
)
return output
def prepare_decoder_input(self, input_dict, output):
decoder_input = super().prepare_decoder_input(input_dict, output)
decoder_input["keyword"] = input_dict["keyword"]
return decoder_input
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
t = input_dict["t"]
i = input_dict["sample_idx"]
beam_size = input_dict["beam_size"]
if t == 0:
output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
beam_size)
decoder_input["keyword"] = output_i["keyword"]
return decoder_input
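# --- Inference sketch (illustrative; mirrors AudioCapModel in this repo) ---
# A TransformerModel wraps an audio encoder and a TransformerDecoder. Assuming
# the encoder consumes raw waveforms under the keys "wav"/"wav_len", a
# beam-search inference call could look like:
#
#     model = TransformerModel(encoder, decoder)
#     output = model({
#         "mode": "inference",
#         "wav": wav,                  # [N, num_samples] float tensor
#         "wav_len": wav_len,          # [N] lengths in samples
#         "specaug": False,
#         "sample_method": "beam",
#         "beam_size": 3,              # consumed by prepare_beamsearch_decoder_input
#     })
#     # output["seq"]: [N, T] token ids; map back to words with the vocabulary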

View File

@@ -0,0 +1,132 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
def sort_pack_padded_sequence(input, lengths):
sorted_lengths, indices = torch.sort(lengths, descending=True)
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
inv_ix = indices.clone()
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
return tmp, inv_ix
def pad_unsort_packed_sequence(input, inv_ix):
tmp, _ = pad_packed_sequence(input, batch_first=True)
tmp = tmp[inv_ix]
return tmp
def pack_wrapper(module, attn_feats, attn_feat_lens):
packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
if isinstance(module, torch.nn.RNNBase):
return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
else:
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
def generate_length_mask(lens, max_length=None):
lens = torch.as_tensor(lens)
N = lens.size(0)
if max_length is None:
max_length = max(lens)
idxs = torch.arange(max_length).repeat(N).view(N, max_length)
idxs = idxs.to(lens.device)
mask = (idxs < lens.view(-1, 1))
return mask
def mean_with_lens(features, lens):
"""
features: [N, T, ...] (assume the second dimension represents length)
lens: [N,]
"""
lens = torch.as_tensor(lens)
if max(lens) != features.size(1):
max_length = features.size(1)
mask = generate_length_mask(lens, max_length)
else:
mask = generate_length_mask(lens)
mask = mask.to(features.device) # [N, T]
while mask.ndim < features.ndim:
mask = mask.unsqueeze(-1)
feature_mean = features * mask
feature_mean = feature_mean.sum(1)
while lens.ndim < feature_mean.ndim:
lens = lens.unsqueeze(1)
feature_mean = feature_mean / lens.to(features.device)
# feature_mean = features * mask.unsqueeze(-1)
# feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
return feature_mean
def max_with_lens(features, lens):
"""
features: [N, T, ...] (assume the second dimension represents length)
lens: [N,]
"""
lens = torch.as_tensor(lens)
mask = generate_length_mask(lens).to(features.device) # [N, T]
feature_max = features.clone()
feature_max[~mask] = float("-inf")
feature_max, _ = feature_max.max(1)
return feature_max
def repeat_tensor(x, n):
return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
def init(m, method="kaiming"):
if isinstance(m, (nn.Conv2d, nn.Conv1d)):
if method == "kaiming":
nn.init.kaiming_uniform_(m.weight)
elif method == "xavier":
nn.init.xavier_uniform_(m.weight)
else:
raise Exception(f"initialization method {method} not supported")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
if method == "kaiming":
nn.init.kaiming_uniform_(m.weight)
elif method == "xavier":
nn.init.xavier_uniform_(m.weight)
else:
raise Exception(f"initialization method {method} not supported")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Embedding):
if method == "kaiming":
nn.init.kaiming_uniform_(m.weight)
elif method == "xavier":
nn.init.xavier_uniform_(m.weight)
else:
raise Exception(f"initialization method {method} not supported")
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=100):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
# self.register_buffer("pe", pe)
self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
def forward(self, x):
# x: [T, N, E]
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
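# --- Shape sanity check (illustrative) ---
# The masked pooling helpers above can be exercised like this:
#
#     feats = torch.randn(2, 5, 8)             # [N, T, E]
#     lens = torch.tensor([5, 3])
#     mask = generate_length_mask(lens)        # [2, 5] bool, True = valid step
#     mean = mean_with_lens(feats, lens)       # [2, 8], padding excluded
#     mx = max_with_lens(feats, lens)          # [2, 8], padding treated as -inf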

View File

@@ -0,0 +1,19 @@
# Utils
Scripts in this directory provide utility functions used across the project.
## BERT Pretrained Embeddings
You can load pretrained word embeddings from Google [BERT](https://github.com/google-research/bert#pre-trained-models) instead of training word embeddings from scratch. The scripts in `utils/bert` need a BERT server running in the background; we use the server from [bert-as-service](https://github.com/hanxiao/bert-as-service).<br />
To use bert-as-service, first install the repository. It is recommended to create a separate environment with TensorFlow 1.x to run the BERT server, since bert-as-service is incompatible with TensorFlow 2.x.<br />
After installing [bert-as-service](https://github.com/hanxiao/bert-as-service), download and start the BERT server by executing:
```bash
bash scripts/prepare_bert_server.sh <path-to-server> <num-workers> zh
```
By default, a server based on the BERT-Base Chinese model runs in the background. You can switch to other models by changing the corresponding model name and path in `scripts/prepare_bert_server.sh`.
To extract BERT word embeddings, run `utils/bert/create_word_embedding.py`.
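For example, with the BERT server running on `localhost` (file paths below are placeholders; the script takes `vocab_file`, `output` and `server_hostname`):
```bash
python utils/bert/create_word_embedding.py \
    --vocab_file data/vocab.pkl \
    --output data/bert_word_embeddings.npy \
    --server_hostname localhost
```
The output `.npy` matrix is aligned with the vocabulary indices, so it can be loaded directly as a word embedding table.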

View File

@@ -0,0 +1,89 @@
import pickle
import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
class EmbeddingExtractor(object):
def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
from sentence_transformers import SentenceTransformer
lang2model = {
"zh": "distiluse-base-multilingual-cased",
"en": "bert-base-nli-mean-tokens"
}
lang = "zh" if zh else "en"
model = SentenceTransformer(lang2model[lang])
self.extract(caption_file, model, output, dev)
def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
from bert_serving.client import BertClient
client = BertClient(ip)
self.extract(caption_file, client, output, dev)
def extract(self, caption_file: str, model, output, dev: bool):
caption_df = pd.read_json(caption_file, dtype={"key": str})
embeddings = {}
if dev:
with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
for idx, row in caption_df.iterrows():
caption = row["caption"]
key = row["key"]
cap_idx = row["caption_index"]
embedding = model.encode([caption])
embedding = np.array(embedding).reshape(-1)
embeddings[f"{key}_{cap_idx}"] = embedding
pbar.update()
else:
dump = {}
with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
for idx, row in caption_df.iterrows():
key = row["key"]
caption = row["caption"]
value = np.array(model.encode([caption])).reshape(-1)
if key not in embeddings.keys():
embeddings[key] = [value]
else:
embeddings[key].append(value)
pbar.update()
for key in embeddings:
dump[key] = np.stack(embeddings[key])
embeddings = dump
with open(output, "wb") as f:
pickle.dump(embeddings, f)
def extract_sbert(self,
input_json: str,
output: str):
from sentence_transformers import SentenceTransformer
import json
import torch
from h5py import File
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
model = model.to(device)
model.eval()
data = json.load(open(input_json))["audios"]
with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
for sample in data:
audio_id = sample["audio_id"]
for cap in sample["captions"]:
cap_id = cap["cap_id"]
store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
pbar.update()
if __name__ == "__main__":
fire.Fire(EmbeddingExtractor)
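# Usage sketch (subcommands map to the methods above; the script name and file
# paths below are placeholders):
#   python <this_script>.py extract_sbert --input_json data/dev.json --output sbert_embs.h5
#   python <this_script>.py extract_sentbert --caption_file data/dev.json --output embs.pkl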

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
import sys
import os
from bert_serving.client import BertClient
import numpy as np
from tqdm import tqdm
import fire
import torch
sys.path.append(os.getcwd())
from utils.build_vocab import Vocabulary
def main(vocab_file: str, output: str, server_hostname: str):
client = BertClient(ip=server_hostname)
vocabulary = torch.load(vocab_file)
vocab_size = len(vocabulary)
fake_embedding = client.encode(["test"]).reshape(-1)
embed_size = fake_embedding.shape[0]
print("Encoding words into embeddings with size: ", embed_size)
embeddings = np.empty((vocab_size, embed_size))
for i in tqdm(range(len(embeddings)), ascii=True):
embeddings[i] = client.encode([vocabulary.idx2word[i]])
np.save(output, embeddings)
if __name__ == '__main__':
fire.Fire(main)

View File

@@ -0,0 +1,153 @@
import json
from tqdm import tqdm
import logging
import pickle
from collections import Counter
import re
import fire
class Vocabulary(object):
"""Simple vocabulary wrapper."""
def __init__(self):
self.word2idx = {}
self.idx2word = {}
self.idx = 0
def add_word(self, word):
if not word in self.word2idx:
self.word2idx[word] = self.idx
self.idx2word[self.idx] = word
self.idx += 1
def __call__(self, word):
if not word in self.word2idx:
return self.word2idx["<unk>"]
return self.word2idx[word]
def __getitem__(self, word_id):
return self.idx2word[word_id]
def __len__(self):
return len(self.word2idx)
def build_vocab(input_json: str,
threshold: int,
keep_punctuation: bool,
host_address: str,
character_level: bool = False,
zh: bool = True ):
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
Args:
input_json(string): Preprossessed json file. Structure like this:
{
'audios': [
{
'audio_id': 'xxx',
'captions': [
{
'caption': 'xxx',
'cap_id': 'xxx'
}
]
},
...
]
}
threshold (int): Threshold to drop all words with counts < threshold
keep_punctuation (bool): Includes or excludes punctuation.
Returns:
vocab (Vocab): Object with the processed vocabulary
"""
data = json.load(open(input_json, "r"))["audios"]
counter = Counter()
pretokenized = "tokens" in data[0]["captions"][0]
if zh:
from nltk.parse.corenlp import CoreNLPParser
from zhon.hanzi import punctuation
if not pretokenized:
parser = CoreNLPParser(host_address)
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
for cap_idx in range(len(data[audio_idx]["captions"])):
if pretokenized:
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
else:
caption = data[audio_idx]["captions"][cap_idx]["caption"]
# Remove all punctuations
if not keep_punctuation:
caption = re.sub("[{}]".format(punctuation), "", caption)
if character_level:
tokens = list(caption)
else:
tokens = list(parser.tokenize(caption))
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
counter.update(tokens)
else:
if pretokenized:
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
for cap_idx in range(len(data[audio_idx]["captions"])):
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
counter.update(tokens)
else:
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
captions = {}
for audio_idx in range(len(data)):
audio_id = data[audio_idx]["audio_id"]
captions[audio_id] = []
for cap_idx in range(len(data[audio_idx]["captions"])):
caption = data[audio_idx]["captions"][cap_idx]["caption"]
captions[audio_id].append({
"audio_id": audio_id,
"id": cap_idx,
"caption": caption
})
tokenizer = PTBTokenizer()
captions = tokenizer.tokenize(captions)
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
audio_id = data[audio_idx]["audio_id"]
for cap_idx in range(len(data[audio_idx]["captions"])):
tokens = captions[audio_id][cap_idx]
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
counter.update(tokens.split(" "))
if not pretokenized:
json.dump({ "audios": data }, open(input_json, "w"), indent=4, ensure_ascii=not zh)
words = [word for word, cnt in counter.items() if cnt >= threshold]
# Create a vocab wrapper and add some special tokens.
vocab = Vocabulary()
vocab.add_word("<pad>")
vocab.add_word("<start>")
vocab.add_word("<end>")
vocab.add_word("<unk>")
# Add the words to the vocabulary.
for word in words:
vocab.add_word(word)
return vocab
def process(input_json: str,
output_file: str,
threshold: int = 1,
keep_punctuation: bool = False,
character_level: bool = False,
host_address: str = "http://localhost:9000",
zh: bool = False):
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=logfmt)
logging.info("Build Vocab")
vocabulary = build_vocab(
input_json=input_json, threshold=threshold, keep_punctuation=keep_punctuation,
host_address=host_address, character_level=character_level, zh=zh)
pickle.dump(vocabulary, open(output_file, "wb"))
logging.info("Total vocabulary size: {}".format(len(vocabulary)))
logging.info("Saved vocab to '{}'".format(output_file))
if __name__ == '__main__':
fire.Fire(process)
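# Usage sketch (hypothetical paths; fire maps positional/keyword args onto process()):
#   python build_vocab.py data/audiocaps/train.json data/audiocaps/vocab.pkl --zh False
# With zh=True, a CoreNLP server must be reachable at host_address.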

View File

@@ -0,0 +1,150 @@
import json
from tqdm import tqdm
import logging
import pickle
from collections import Counter
import re
import fire
class Vocabulary(object):
"""Simple vocabulary wrapper."""
def __init__(self):
self.word2idx = {}
self.idx2word = {}
self.idx = 0
def add_word(self, word):
if not word in self.word2idx:
self.word2idx[word] = self.idx
self.idx2word[self.idx] = word
self.idx += 1
def __call__(self, word):
if not word in self.word2idx:
return self.word2idx["<unk>"]
return self.word2idx[word]
def __len__(self):
return len(self.word2idx)
def build_vocab(input_json: str,
output_json: str,
threshold: int,
keep_punctuation: bool,
character_level: bool = False,
zh: bool = True ):
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
Args:
input_json(string): Preprossessed json file. Structure like this:
{
'audios': [
{
'audio_id': 'xxx',
'captions': [
{
'caption': 'xxx',
'cap_id': 'xxx'
}
]
},
...
]
}
threshold (int): Threshold to drop all words with counts < threshold
keep_punctuation (bool): Includes or excludes punctuation.
Returns:
vocab (Vocab): Object with the processed vocabulary
"""
data = json.load(open(input_json, "r"))["audios"]
counter = Counter()
pretokenized = "tokens" in data[0]["captions"][0]
if zh:
from ltp import LTP
from zhon.hanzi import punctuation
if not pretokenized:
parser = LTP("base")
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
for cap_idx in range(len(data[audio_idx]["captions"])):
if pretokenized:
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
else:
caption = data[audio_idx]["captions"][cap_idx]["caption"]
if character_level:
tokens = list(caption)
else:
tokens, _ = parser.seg([caption])
tokens = tokens[0]
# Remove all punctuations
if not keep_punctuation:
tokens = [token for token in tokens if token not in punctuation]
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
counter.update(tokens)
else:
if pretokenized:
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
for cap_idx in range(len(data[audio_idx]["captions"])):
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
counter.update(tokens)
else:
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
captions = {}
for audio_idx in range(len(data)):
audio_id = data[audio_idx]["audio_id"]
captions[audio_id] = []
for cap_idx in range(len(data[audio_idx]["captions"])):
caption = data[audio_idx]["captions"][cap_idx]["caption"]
captions[audio_id].append({
"audio_id": audio_id,
"id": cap_idx,
"caption": caption
})
tokenizer = PTBTokenizer()
captions = tokenizer.tokenize(captions)
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
audio_id = data[audio_idx]["audio_id"]
for cap_idx in range(len(data[audio_idx]["captions"])):
tokens = captions[audio_id][cap_idx]
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
counter.update(tokens.split(" "))
if not pretokenized:
if output_json is None:
output_json = input_json
json.dump({ "audios": data }, open(output_json, "w"), indent=4, ensure_ascii=not zh)
words = [word for word, cnt in counter.items() if cnt >= threshold]
# Create a vocab wrapper and add some special tokens.
vocab = Vocabulary()
vocab.add_word("<pad>")
vocab.add_word("<start>")
vocab.add_word("<end>")
vocab.add_word("<unk>")
# Add the words to the vocabulary.
for word in words:
vocab.add_word(word)
return vocab
def process(input_json: str,
output_file: str,
output_json: str = None,
threshold: int = 1,
keep_punctuation: bool = False,
character_level: bool = False,
zh: bool = True):
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=logfmt)
logging.info("Build Vocab")
vocabulary = build_vocab(
input_json=input_json, output_json=output_json, threshold=threshold,
keep_punctuation=keep_punctuation, character_level=character_level, zh=zh)
pickle.dump(vocabulary, open(output_file, "wb"))
logging.info("Total vocabulary size: {}".format(len(vocabulary)))
logging.info("Saved vocab to '{}'".format(output_file))
if __name__ == '__main__':
fire.Fire(process)

View File

@@ -0,0 +1,152 @@
import json
from tqdm import tqdm
import logging
import pickle
from collections import Counter
import re
import fire
class Vocabulary(object):
"""Simple vocabulary wrapper."""
def __init__(self):
self.word2idx = {}
self.idx2word = {}
self.idx = 0
def add_word(self, word):
if not word in self.word2idx:
self.word2idx[word] = self.idx
self.idx2word[self.idx] = word
self.idx += 1
def __call__(self, word):
if not word in self.word2idx:
return self.word2idx["<unk>"]
return self.word2idx[word]
def __len__(self):
return len(self.word2idx)
def build_vocab(input_json: str,
output_json: str,
threshold: int,
keep_punctuation: bool,
host_address: str,
character_level: bool = False,
retokenize: bool = True,
zh: bool = True ):
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
Args:
input_json(string): Preprossessed json file. Structure like this:
{
'audios': [
{
'audio_id': 'xxx',
'captions': [
{
'caption': 'xxx',
'cap_id': 'xxx'
}
]
},
...
]
}
threshold (int): Threshold to drop all words with counts < threshold
keep_punctuation (bool): Includes or excludes punctuation.
Returns:
vocab (Vocab): Object with the processed vocabulary
"""
data = json.load(open(input_json, "r"))["audios"]
counter = Counter()
if retokenize:
pretokenized = False
else:
pretokenized = "tokens" in data[0]["captions"][0]
if zh:
from nltk.parse.corenlp import CoreNLPParser
from zhon.hanzi import punctuation
if not pretokenized:
parser = CoreNLPParser(host_address)
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
for cap_idx in range(len(data[audio_idx]["captions"])):
if pretokenized:
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
else:
caption = data[audio_idx]["captions"][cap_idx]["caption"]
# Remove all punctuations
if not keep_punctuation:
caption = re.sub("[{}]".format(punctuation), "", caption)
if character_level:
tokens = list(caption)
else:
tokens = list(parser.tokenize(caption))
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
counter.update(tokens)
else:
if pretokenized:
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
for cap_idx in range(len(data[audio_idx]["captions"])):
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
counter.update(tokens)
else:
import spacy
tokenizer = spacy.load("en_core_web_sm", disable=["parser", "ner"])
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
captions = data[audio_idx]["captions"]
for cap_idx in range(len(captions)):
caption = captions[cap_idx]["caption"]
doc = tokenizer(caption)
tokens = " ".join([str(token).lower() for token in doc])
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
counter.update(tokens.split(" "))
if not pretokenized:
if output_json is None:
json.dump({ "audios": data }, open(input_json, "w"),
indent=4, ensure_ascii=not zh)
else:
json.dump({ "audios": data }, open(output_json, "w"),
indent=4, ensure_ascii=not zh)
words = [word for word, cnt in counter.items() if cnt >= threshold]
# Create a vocab wrapper and add some special tokens.
vocab = Vocabulary()
vocab.add_word("<pad>")
vocab.add_word("<start>")
vocab.add_word("<end>")
vocab.add_word("<unk>")
# Add the words to the vocabulary.
for word in words:
vocab.add_word(word)
return vocab
def process(input_json: str,
output_file: str,
output_json: str = None,
threshold: int = 1,
keep_punctuation: bool = False,
character_level: bool = False,
retokenize: bool = False,
host_address: str = "http://localhost:9000",
zh: bool = True):
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=logfmt)
logging.info("Build Vocab")
vocabulary = build_vocab(
input_json=input_json, output_json=output_json, threshold=threshold,
keep_punctuation=keep_punctuation, host_address=host_address,
character_level=character_level, retokenize=retokenize, zh=zh)
pickle.dump(vocabulary, open(output_file, "wb"))
logging.info("Total vocabulary size: {}".format(len(vocabulary)))
logging.info("Saved vocab to '{}'".format(output_file))
if __name__ == '__main__':
fire.Fire(process)

View File

@@ -0,0 +1,182 @@
import copy
import json
import numpy as np
import fire
def evaluate_annotation(key2refs, scorer):
if scorer.method() == "Bleu":
scores = np.array([ 0.0 for n in range(4) ])
else:
scores = 0
num_cap_per_audio = len(next(iter(key2refs.values())))
for i in range(num_cap_per_audio):
if i > 0:
for key in key2refs:
key2refs[key].insert(0, res[key][0])
res = { key: [refs.pop(),] for key, refs in key2refs.items() }
score, _ = scorer.compute_score(key2refs, res)
if scorer.method() == "Bleu":
scores += np.array(score)
else:
scores += score
score = scores / num_cap_per_audio
return score
def evaluate_prediction(key2pred, key2refs, scorer):
if scorer.method() == "Bleu":
scores = np.array([ 0.0 for n in range(4) ])
else:
scores = 0
num_cap_per_audio = len(next(iter(key2refs.values())))
for i in range(num_cap_per_audio):
key2refs_i = {}
for key, refs in key2refs.items():
key2refs_i[key] = refs[:i] + refs[i+1:]
score, _ = scorer.compute_score(key2refs_i, key2pred)
if scorer.method() == "Bleu":
scores += np.array(score)
else:
scores += score
score = scores / num_cap_per_audio
return score
class Evaluator(object):
def eval_annotation(self, annotation, output):
captions = json.load(open(annotation, "r"))["audios"]
key2refs = {}
for audio_idx in range(len(captions)):
audio_id = captions[audio_idx]["audio_id"]
key2refs[audio_id] = []
for caption in captions[audio_idx]["captions"]:
key2refs[audio_id].append(caption["caption"])
from fense.fense import Fense
scores = {}
scorer = Fense()
scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
refs4eval = {}
for key, refs in key2refs.items():
refs4eval[key] = []
for idx, ref in enumerate(refs):
refs4eval[key].append({
"audio_id": key,
"id": idx,
"caption": ref
})
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
tokenizer = PTBTokenizer()
key2refs = tokenizer.tokenize(refs4eval)
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.spice.spice import Spice
scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
for scorer in scorers:
scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
spider = 0
with open(output, "w") as f:
for name, score in scores.items():
if name == "Bleu":
for n in range(4):
f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
else:
f.write("{}: {:6.3f}\n".format(name, score))
if name in ["CIDEr", "SPICE"]:
spider += score
f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
def eval_prediction(self, prediction, annotation, output):
ref_captions = json.load(open(annotation, "r"))["audios"]
key2refs = {}
for audio_idx in range(len(ref_captions)):
audio_id = ref_captions[audio_idx]["audio_id"]
key2refs[audio_id] = []
for caption in ref_captions[audio_idx]["captions"]:
key2refs[audio_id].append(caption["caption"])
pred_captions = json.load(open(prediction, "r"))["predictions"]
key2pred = {}
for audio_idx in range(len(pred_captions)):
item = pred_captions[audio_idx]
audio_id = item["filename"]
key2pred[audio_id] = [item["tokens"]]
from fense.fense import Fense
scores = {}
scorer = Fense()
scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
refs4eval = {}
for key, refs in key2refs.items():
refs4eval[key] = []
for idx, ref in enumerate(refs):
refs4eval[key].append({
"audio_id": key,
"id": idx,
"caption": ref
})
preds4eval = {}
for key, preds in key2pred.items():
preds4eval[key] = []
for idx, pred in enumerate(preds):
preds4eval[key].append({
"audio_id": key,
"id": idx,
"caption": pred
})
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
tokenizer = PTBTokenizer()
key2refs = tokenizer.tokenize(refs4eval)
key2pred = tokenizer.tokenize(preds4eval)
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.spice.spice import Spice
scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
for scorer in scorers:
scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
spider = 0
with open(output, "w") as f:
for name, score in scores.items():
if name == "Bleu":
for n in range(4):
f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
else:
f.write("{}: {:6.3f}\n".format(name, score))
if name in ["CIDEr", "SPICE"]:
spider += score
f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
if __name__ == "__main__":
fire.Fire(Evaluator)
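# Usage sketch (hypothetical paths; subcommands map to the Evaluator methods):
#   python <this_script>.py eval_prediction --prediction exp/predictions.json \
#       --annotation data/test.json --output exp/scores.txt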

View File

@@ -0,0 +1,50 @@
# coding=utf-8
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import torch
from gensim.models import FastText
from tqdm import tqdm
import fire
import sys
import os
sys.path.append(os.getcwd())
from utils.build_vocab import Vocabulary
def create_embedding(caption_file: str,
vocab_file: str,
embed_size: int,
output: str,
**fasttext_kwargs):
caption_df = pd.read_json(caption_file)
caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])
sentences = list(caption_df["tokens"].values)
vocabulary = torch.load(vocab_file, map_location="cpu")
epochs = fasttext_kwargs.pop("epochs", 10)  # gensim 3.x FastText has no `epochs` kwarg
model = FastText(size=embed_size, min_count=1, **fasttext_kwargs)
model.build_vocab(sentences=sentences)
model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
word_embeddings = np.zeros((len(vocabulary), embed_size))
with tqdm(total=len(vocabulary), ascii=True) as pbar:
for word, idx in vocabulary.word2idx.items():
if word == "<pad>" or word == "<unk>":
continue
word_embeddings[idx] = model.wv[word]
pbar.update()
np.save(output, word_embeddings)
print("Finish writing fasttext embeddings to " + output)
if __name__ == "__main__":
fire.Fire(create_embedding)
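# Usage sketch (hypothetical paths; note FastText(size=...) follows the
# gensim 3.x API -- gensim 4 renamed the argument to vector_size):
#   python <this_script>.py --caption_file data/train.json --vocab_file vocab.pkl \
#       --embed_size 256 --output fasttext_embeds.npy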

View File

@@ -0,0 +1,128 @@
import math
import torch
class ExponentialDecayScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, total_iters, final_lrs,
warmup_iters=3000, last_epoch=-1, verbose=False):
self.total_iters = total_iters
self.final_lrs = final_lrs
if not isinstance(self.final_lrs, list) and not isinstance(
self.final_lrs, tuple):
self.final_lrs = [self.final_lrs] * len(optimizer.param_groups)
self.warmup_iters = warmup_iters
self.bases = [0.0,] * len(optimizer.param_groups)
super().__init__(optimizer, last_epoch, verbose)
for i, (base_lr, final_lr) in enumerate(zip(self.base_lrs, self.final_lrs)):
base = (final_lr / base_lr) ** (1 / (
self.total_iters - self.warmup_iters))
self.bases[i] = base
def _get_closed_form_lr(self):
warmup_coeff = 1.0
current_iter = self._step_count
if current_iter < self.warmup_iters:
warmup_coeff = current_iter / self.warmup_iters
current_lrs = []
# if not self.linear_warmup:
# for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs, self.bases):
# # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
# current_lr = warmup_coeff * base_lr * (base ** (current_iter - self.warmup_iters))
# current_lrs.append(current_lr)
# else:
for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs,
self.bases):
if current_iter <= self.warmup_iters:
current_lr = warmup_coeff * base_lr
else:
# current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
current_lr = base_lr * (base ** (current_iter - self.warmup_iters))
current_lrs.append(current_lr)
return current_lrs
def get_lr(self):
return self._get_closed_form_lr()
class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, model_size=512, factor=1, warmup_iters=3000,
last_epoch=-1, verbose=False):
self.model_size = model_size
self.warmup_iters = warmup_iters
# self.factors = [group["lr"] / (self.model_size ** (-0.5) * self.warmup_iters ** (-0.5)) for group in optimizer.param_groups]
self.factor = factor
super().__init__(optimizer, last_epoch, verbose)
def _get_closed_form_lr(self):
current_iter = self._step_count
current_lrs = []
for _ in self.base_lrs:
current_lr = self.factor * \
(self.model_size ** (-0.5) * min(current_iter ** (-0.5),
current_iter * self.warmup_iters ** (-1.5)))
current_lrs.append(current_lr)
return current_lrs
def get_lr(self):
return self._get_closed_form_lr()
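# The schedule above follows "Attention Is All You Need":
#   lr(step) = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup_iters^(-1.5))
# i.e. linear warmup for warmup_iters steps, then inverse-square-root decay.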
class CosineWithWarmup(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, total_iters, warmup_iters,
num_cycles=0.5, last_epoch=-1, verbose=False):
self.total_iters = total_iters
self.warmup_iters = warmup_iters
self.num_cycles = num_cycles
super().__init__(optimizer, last_epoch, verbose)
def lr_lambda(self, iteration):
if iteration < self.warmup_iters:
return float(iteration) / float(max(1, self.warmup_iters))
progress = float(iteration - self.warmup_iters) / float(max(1,
self.total_iters - self.warmup_iters))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(
self.num_cycles) * 2.0 * progress)))
def _get_closed_form_lr(self):
current_iter = self._step_count
current_lrs = []
for base_lr in self.base_lrs:
current_lr = base_lr * self.lr_lambda(current_iter)
current_lrs.append(current_lr)
return current_lrs
def get_lr(self):
return self._get_closed_form_lr()
if __name__ == "__main__":
model = torch.nn.Linear(10, 5)
optimizer = torch.optim.Adam(model.parameters(), 5e-4)
epochs = 25
iters = 600
scheduler = CosineWithWarmup(optimizer, iters * epochs, iters * 5)
# scheduler = ExponentialDecayScheduler(optimizer, 600 * 25, 5e-7, 600 * 5)
criterion = torch.nn.MSELoss()
lrs = []
for epoch in range(1, epochs + 1):
for iteration in range(1, iters + 1):
optimizer.zero_grad()
x = torch.randn(4, 10)
y = torch.randn(4, 5)
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
scheduler.step()
# print(f"lr: {scheduler.get_last_lr()}")
# lrs.append(scheduler.get_last_lr())
lrs.append(optimizer.param_groups[0]["lr"])
import matplotlib.pyplot as plt
plt.plot(list(range(1, len(lrs) + 1)), lrs, '-o', markersize=1)
# plt.legend(loc="best")
plt.xlabel("Iteration")
plt.ylabel("LR")
plt.savefig("lr_curve.png", dpi=100)

View File

@@ -0,0 +1,110 @@
import os
import sys
import copy
import pickle
import numpy as np
import pandas as pd
import fire
sys.path.append(os.getcwd())
def coco_score(refs, pred, scorer):
if scorer.method() == "Bleu":
scores = np.array([ 0.0 for n in range(4) ])
else:
scores = 0
num_cap_per_audio = len(refs[list(refs.keys())[0]])
for i in range(num_cap_per_audio):
if i > 0:
for key in refs:
refs[key].insert(0, res[key][0])
res = {key: [refs[key].pop(),] for key in refs}
score, _ = scorer.compute_score(refs, pred)
if scorer.method() == "Bleu":
scores += np.array(score)
else:
scores += score
score = scores / num_cap_per_audio
for key in refs:
refs[key].insert(0, res[key][0])
score_allref, _ = scorer.compute_score(refs, pred)
diff = score_allref - score
return diff
def embedding_score(refs, pred, scorer):
num_cap_per_audio = len(refs[list(refs.keys())[0]])
scores = 0
for i in range(num_cap_per_audio):
res = {key: [refs[key][i],] for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
refs_i = {key: np.concatenate([refs[key][:i], refs[key][i+1:]]) for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
score, _ = scorer.compute_score(refs_i, pred)
scores += score
score = scores / num_cap_per_audio
score_allref, _ = scorer.compute_score(refs, pred)
diff = score_allref - score
return diff
def main(output_file, eval_caption_file, eval_embedding_file, output, zh=False):
output_df = pd.read_json(output_file)
output_df["key"] = output_df["filename"].apply(lambda x: os.path.splitext(os.path.basename(x))[0])
pred = output_df.groupby("key")["tokens"].apply(list).to_dict()
label_df = pd.read_json(eval_caption_file)
if zh:
refs = label_df.groupby("key")["tokens"].apply(list).to_dict()
else:
refs = label_df.groupby("key")["caption"].apply(list).to_dict()
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.rouge.rouge import Rouge
scorer = Bleu(zh=zh)
bleu_scores = coco_score(copy.deepcopy(refs), pred, scorer)
scorer = Cider(zh=zh)
cider_score = coco_score(copy.deepcopy(refs), pred, scorer)
scorer = Rouge(zh=zh)
rouge_score = coco_score(copy.deepcopy(refs), pred, scorer)
if not zh:
from pycocoevalcap.meteor.meteor import Meteor
scorer = Meteor()
meteor_score = coco_score(copy.deepcopy(refs), pred, scorer)
from pycocoevalcap.spice.spice import Spice
scorer = Spice()
spice_score = coco_score(copy.deepcopy(refs), pred, scorer)
# from audiocaptioneval.sentbert.sentencebert import SentenceBert
# scorer = SentenceBert(zh=zh)
# with open(eval_embedding_file, "rb") as f:
# ref_embeddings = pickle.load(f)
# sent_bert = embedding_score(ref_embeddings, pred, scorer)
with open(output, "w") as f:
f.write("Diff:\n")
for n in range(4):
f.write("BLEU-{}: {:6.3f}\n".format(n+1, bleu_scores[n]))
f.write("CIDEr: {:6.3f}\n".format(cider_score))
f.write("ROUGE: {:6.3f}\n".format(rouge_score))
if not zh:
f.write("Meteor: {:6.3f}\n".format(meteor_score))
f.write("SPICE: {:6.3f}\n".format(spice_score))
# f.write("SentenceBert: {:6.3f}\n".format(sent_bert))
if __name__ == "__main__":
fire.Fire(main)

View File

@@ -0,0 +1,49 @@
import json
import random
import argparse
import numpy as np
from tqdm import tqdm
from h5py import File
import sklearn.metrics
random.seed(1)
parser = argparse.ArgumentParser()
parser.add_argument("train_feature", type=str)
parser.add_argument("train_corpus", type=str)
parser.add_argument("pred_feature", type=str)
parser.add_argument("output_json", type=str)
args = parser.parse_args()
train_embs = []
train_idx_to_audioid = []
with File(args.train_feature, "r") as store:
for audio_id, embedding in tqdm(store.items(), ascii=True):
train_embs.append(embedding[()])
train_idx_to_audioid.append(audio_id)
train_annotation = json.load(open(args.train_corpus, "r"))["audios"]
train_audioid_to_tokens = {}
for item in train_annotation:
audio_id = item["audio_id"]
train_audioid_to_tokens[audio_id] = [cap_item["tokens"] for cap_item in item["captions"]]
train_embs = np.stack(train_embs)
pred_data = []
pred_embs = []
pred_idx_to_audioids = []
with File(args.pred_feature, "r") as store:
for audio_id, embedding in tqdm(store.items(), ascii=True):
pred_embs.append(embedding[()])
pred_idx_to_audioids.append(audio_id)
pred_embs = np.stack(pred_embs)
similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
for idx, audio_id in enumerate(pred_idx_to_audioids):
train_idx = similarity[idx].argmax()
pred_data.append({
"filename": audio_id,
"tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]])
})
json.dump({"predictions": pred_data}, open(args.output_json, "w"), ensure_ascii=False, indent=4)

View File

@@ -0,0 +1,18 @@
import argparse
import torch
def main(checkpoint):
state_dict = torch.load(checkpoint, map_location="cpu")
if "optimizer" in state_dict:
del state_dict["optimizer"]
if "lr_scheduler" in state_dict:
del state_dict["lr_scheduler"]
torch.save(state_dict, checkpoint)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("checkpoint", type=str)
args = parser.parse_args()
main(args.checkpoint)

View File

@@ -0,0 +1,37 @@
from pathlib import Path
import argparse
import numpy as np
parser = argparse.ArgumentParser()
parser.add_argument("--input", help="input filename", type=str, nargs="+")
parser.add_argument("--output", help="output result file", default=None)
args = parser.parse_args()
scores = {}
for path in args.input:
with open(path, "r") as reader:
for line in reader.readlines():
metric, score = line.strip().split(": ")
score = float(score)
if metric not in scores:
scores[metric] = []
scores[metric].append(score)
if len(scores) == 0:
print("No experiment directory found, wrong path?")
exit(1)
with open(args.output, "w") as writer:
print("Average results: ", file=writer)
for metric, score in scores.items():
score = np.array(score)
mean = np.mean(score)
std = np.std(score)
print(f"{metric}: {mean:.3f}{std:.3f})", file=writer)
print("", file=writer)
print("Best results: ", file=writer)
for metric, score in scores.items():
score = np.max(score)
print(f"{metric}: {score:.3f}", file=writer)

View File

@@ -0,0 +1,86 @@
import json
from tqdm import tqdm
import re
import fire
def tokenize_caption(input_json: str,
keep_punctuation: bool = False,
host_address: str = None,
character_level: bool = False,
zh: bool = True,
output_json: str = None):
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
Args:
input_json(string): Preprossessed json file. Structure like this:
{
'audios': [
{
'audio_id': 'xxx',
'captions': [
{
'caption': 'xxx',
'cap_id': 'xxx'
}
]
},
...
]
}
keep_punctuation (bool): Includes or excludes punctuation (Chinese only).
host_address (string): CoreNLP server address, used when zh is True.
character_level (bool): Tokenize Chinese captions into characters.
zh (bool): Whether the captions are in Chinese.
output_json (string): Output path; defaults to overwriting input_json.
"""
data = json.load(open(input_json, "r"))["audios"]
if zh:
from nltk.parse.corenlp import CoreNLPParser
from zhon.hanzi import punctuation
parser = CoreNLPParser(host_address)
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
for cap_idx in range(len(data[audio_idx]["captions"])):
caption = data[audio_idx]["captions"][cap_idx]["caption"]
# Remove all punctuations
if not keep_punctuation:
caption = re.sub("[{}]".format(punctuation), "", caption)
if character_level:
tokens = list(caption)
else:
tokens = list(parser.tokenize(caption))
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
else:
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
captions = {}
for audio_idx in range(len(data)):
audio_id = data[audio_idx]["audio_id"]
captions[audio_id] = []
for cap_idx in range(len(data[audio_idx]["captions"])):
caption = data[audio_idx]["captions"][cap_idx]["caption"]
captions[audio_id].append({
"audio_id": audio_id,
"id": cap_idx,
"caption": caption
})
tokenizer = PTBTokenizer()
captions = tokenizer.tokenize(captions)
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
audio_id = data[audio_idx]["audio_id"]
for cap_idx in range(len(data[audio_idx]["captions"])):
tokens = captions[audio_id][cap_idx]
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
if output_json:
json.dump(
{ "audios": data }, open(output_json, "w"),
indent=4, ensure_ascii=not zh)
else:
json.dump(
{ "audios": data }, open(input_json, "w"),
indent=4, ensure_ascii=not zh)
if __name__ == "__main__":
fire.Fire(tokenize_caption)
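# Usage sketch (hypothetical path; English captions go through the PTB tokenizer,
# Chinese ones need a CoreNLP server at host_address):
#   python <this_script>.py data/clotho/dev.json --zh False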

View File

@@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
#!/usr/bin/env python3
import os
import sys
import logging
from typing import Callable, Dict, Union
import yaml
import torch
from torch.optim.swa_utils import AveragedModel as torch_average_model
import numpy as np
import pandas as pd
from pprint import pformat
def load_dict_from_csv(csv, cols):
df = pd.read_csv(csv, sep="\t")
output = dict(zip(df[cols[0]], df[cols[1]]))
return output
def init_logger(filename, level="INFO"):
formatter = logging.Formatter(
"[ %(levelname)s : %(asctime)s ] - %(message)s")
logger = logging.getLogger(__name__ + "." + filename)
logger.setLevel(getattr(logging, level))
# Log results to std
# stdhandler = logging.StreamHandler(sys.stdout)
# stdhandler.setFormatter(formatter)
# Dump log to file
filehandler = logging.FileHandler(filename)
filehandler.setFormatter(formatter)
logger.addHandler(filehandler)
# logger.addHandler(stdhandler)
return logger
def init_obj(module, config, **kwargs):  # e.g. module = audio_to_text.captioning.models.encoder
obj_args = config["args"].copy()
obj_args.update(kwargs)
return getattr(module, config["type"])(**obj_args)
def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
"""pprint_dict
:param outputfun: function to use, defaults to sys.stdout
:param in_dict: dict to print
"""
if formatter == 'yaml':
format_fun = yaml.dump
elif formatter == 'pretty':
format_fun = pformat
else:
raise ValueError(f"formatter {formatter} not supported")
for line in format_fun(in_dict).split('\n'):
outputfun(line)
def merge_a_into_b(a, b):
# merge dict a into dict b. values in a will overwrite b.
for k, v in a.items():
if isinstance(v, dict) and k in b:
assert isinstance(
b[k], dict
), "Cannot inherit key '{}' from base!".format(k)
merge_a_into_b(v, b[k])
else:
b[k] = v
def load_config(config_file):
with open(config_file, "r") as reader:
config = yaml.load(reader, Loader=yaml.FullLoader)
if "inherit_from" in config:
base_config_file = config["inherit_from"]
base_config_file = os.path.join(
os.path.dirname(config_file), base_config_file
)
assert not os.path.samefile(config_file, base_config_file), \
"inherit from itself"
base_config = load_config(base_config_file)
del config["inherit_from"]
merge_a_into_b(config, base_config)
return base_config
return config
def parse_config_or_kwargs(config_file, **kwargs):
yaml_config = load_config(config_file)
# passed kwargs will override yaml config
args = dict(yaml_config, **kwargs)
return args
def store_yaml(config, config_file):
with open(config_file, "w") as con_writer:
yaml.dump(config, con_writer, indent=4, default_flow_style=False)
class MetricImprover:
def __init__(self, mode):
assert mode in ("min", "max")
self.mode = mode
# min: lower -> better; max: higher -> better
self.best_value = np.inf if mode == "min" else -np.inf
def compare(self, x, best_x):
return x < best_x if self.mode == "min" else x > best_x
def __call__(self, x):
if self.compare(x, self.best_value):
self.best_value = x
return True
return False
def state_dict(self):
return self.__dict__
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def fix_batchnorm(model: torch.nn.Module):
def inner(module):
class_name = module.__class__.__name__
if class_name.find("BatchNorm") != -1:
module.eval()
model.apply(inner)
def load_pretrained_model(model: torch.nn.Module,
pretrained: Union[str, Dict],
output_fn: Callable = sys.stdout.write):
if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
output_fn(f"pretrained {pretrained} not exist!")
return
if hasattr(model, "load_pretrained"):
model.load_pretrained(pretrained)
return
if isinstance(pretrained, dict):
state_dict = pretrained
else:
state_dict = torch.load(pretrained, map_location="cpu")
if "model" in state_dict:
state_dict = state_dict["model"]
model_dict = model.state_dict()
pretrained_dict = {
k: v for k, v in state_dict.items() if (k in model_dict) and (
model_dict[k].shape == v.shape)
}
output_fn(f"Loading pretrained keys {pretrained_dict.keys()}")
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, strict=True)
class AveragedModel(torch_average_model):
def update_parameters(self, model):
for p_swa, p_model in zip(self.parameters(), model.parameters()):
device = p_swa.device
p_model_ = p_model.detach().to(device)
if self.n_averaged == 0:
p_swa.detach().copy_(p_model_)
else:
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
self.n_averaged.to(device)))
for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
device = b_swa.device
b_model_ = b_model.detach().to(device)
if self.n_averaged == 0:
b_swa.detach().copy_(b_model_)
else:
b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
self.n_averaged.to(device)))
self.n_averaged += 1
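# Example of config inheritance as resolved by load_config/merge_a_into_b above
# (illustrative file contents):
#
#   # base.yaml
#   optimizer:
#     type: Adam
#     args: {lr: 0.001}
#
#   # child.yaml
#   inherit_from: base.yaml
#   optimizer:
#     args: {lr: 0.0005}   # only lr is overridden; "type" is inherited from base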

View File

@@ -0,0 +1,67 @@
# coding=utf-8
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import torch
import gensim
from gensim.models import Word2Vec
from tqdm import tqdm
import fire
import sys
import os
sys.path.append(os.getcwd())
from utils.build_vocab import Vocabulary
def create_embedding(vocab_file: str,
embed_size: int,
output: str,
caption_file: str = None,
pretrained_weights_path: str = None,
**word2vec_kwargs):
vocabulary = torch.load(vocab_file, map_location="cpu")
if pretrained_weights_path:
model = gensim.models.KeyedVectors.load_word2vec_format(
fname=pretrained_weights_path,
binary=True,
)
if model.vector_size != embed_size:
assert embed_size < model.vector_size, f"can only reduce dimension via PCA, cannot expand {model.vector_size} to {embed_size}"
from sklearn.decomposition import PCA
pca = PCA(n_components=embed_size)
model.vectors = pca.fit_transform(model.vectors)
else:
caption_df = pd.read_json(caption_file)
caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])
sentences = list(caption_df["tokens"].values)
epochs = word2vec_kwargs.get("epochs", 10)
if "epochs" in word2vec_kwargs:
del word2vec_kwargs["epochs"]
model = Word2Vec(size=embed_size, min_count=1, **word2vec_kwargs)
model.build_vocab(sentences=sentences)
model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
word_embeddings = np.random.randn(len(vocabulary), embed_size)
if isinstance(model, gensim.models.word2vec.Word2Vec):
model = model.wv
with tqdm(total=len(vocabulary), ascii=True) as pbar:
for word, idx in vocabulary.word2idx.items():
try:
word_embeddings[idx] = model.get_vector(word)
except KeyError:
print(f"word {word} not found in word2vec model, it is random initialized!")
pbar.update()
np.save(output, word_embeddings)
print("Finish writing word2vec embeddings to " + output)
if __name__ == "__main__":
fire.Fire(create_embedding)
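# Usage sketch (hypothetical paths; trains word2vec on caption tokens unless
# pretrained_weights_path points to binary word2vec weights; Word2Vec(size=...)
# follows the gensim 3.x API -- gensim 4 renamed it to vector_size):
#   python <this_script>.py --vocab_file vocab.pkl --embed_size 256 \
#       --output w2v_embeds.npy --caption_file data/train.json --epochs 10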

View File

@@ -0,0 +1,102 @@
import os
import sys

import librosa
import numpy as np
import torch

import audio_to_text.captioning.models
import audio_to_text.captioning.models.encoder
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.utils.train_util as train_util


def load_model(config, checkpoint):
    ckpt = torch.load(checkpoint, "cpu")
    encoder_cfg = config["model"]["encoder"]
    encoder = train_util.init_obj(
        audio_to_text.captioning.models.encoder,
        encoder_cfg
    )
    if "pretrained" in encoder_cfg:
        pretrained = encoder_cfg["pretrained"]
        train_util.load_pretrained_model(encoder,
                                         pretrained,
                                         sys.stdout.write)
    decoder_cfg = config["model"]["decoder"]
    if "vocab_size" not in decoder_cfg["args"]:
        decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
    decoder = train_util.init_obj(
        audio_to_text.captioning.models.decoder,
        decoder_cfg
    )
    if "word_embedding" in decoder_cfg:
        decoder.load_word_embedding(**decoder_cfg["word_embedding"])
    if "pretrained" in decoder_cfg:
        pretrained = decoder_cfg["pretrained"]
        train_util.load_pretrained_model(decoder,
                                         pretrained,
                                         sys.stdout.write)
    model = train_util.init_obj(audio_to_text.captioning.models,
                                config["model"],
                                encoder=encoder, decoder=decoder)
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    return {
        "model": model,
        "vocabulary": ckpt["vocabulary"]
    }


def decode_caption(word_ids, vocabulary):
    # Map predicted ids back to words, stopping at <end> and skipping <start>.
    candidate = []
    for word_id in word_ids:
        word = vocabulary[word_id]
        if word == "<end>":
            break
        elif word == "<start>":
            continue
        candidate.append(word)
    return " ".join(candidate)


class AudioCapModel(object):
    def __init__(self, weight_dir, device="cuda"):
        config = os.path.join(weight_dir, "config.yaml")
        self.config = train_util.parse_config_or_kwargs(config)
        checkpoint = os.path.join(weight_dir, "swa.pth")
        resumed = load_model(self.config, checkpoint)
        self.vocabulary = resumed["vocabulary"]
        self.model = resumed["model"].to(device)
        self.device = device

    def caption(self, audio_list):
        # Accept a single waveform, a single file path, or a list of waveforms.
        if isinstance(audio_list, np.ndarray):
            audio_list = [audio_list]
        elif isinstance(audio_list, str):
            audio_list = [librosa.load(audio_list, sr=32000)[0]]
        captions = []
        for wav in audio_list:
            inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
            wav_len = torch.LongTensor([len(wav)])
            input_dict = {
                "mode": "inference",
                "wav": inputwav,
                "wav_len": wav_len,
                "specaug": False,
                "sample_method": "beam",
            }
            out_dict = self.model(input_dict)
            caption_batch = [decode_caption(seq, self.vocabulary)
                             for seq in out_dict["seq"].cpu().numpy()]
            captions.extend(caption_batch)
        return captions

    def __call__(self, audio_list):
        return self.caption(audio_list)
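# Minimal usage sketch: the weight directory matches the checkpoint folder
# fetched by the download script below (config.yaml + swa.pth); the wav path
# is hypothetical.
#
# captioner = AudioCapModel("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm",
#                           device="cpu")
# print(captioner("example.wav"))   # returns a list with one caption string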

View File

@@ -26,4 +26,9 @@ wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/Gene
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/spk_map.json
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/train_f0s_mean_std.npy
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/word_set.json
wget -P text_to_speech/checkpoints/hifi_lj https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/text_to_speech/checkpoints/hifi_lj/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/text_to_speech/checkpoints/hifi_lj/model_ckpt_steps_2076000.ckpt
wget -P text_to_speech/checkpoints/ljspeech/ps_adv_baseline https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/text_to_speech/checkpoints/ljspeech/ps_adv_baseline/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160000.ckpt https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160001.ckpt
# Audio to text
wget -P audio_to_text/audiocaps_cntrstv_cnn14rnn_trm https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth
wget -P audio_to_text/clotho_cntrstv_cnn14rnn_trm https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth
wget -P audio_to_text/pretrained_feature_extractors https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth