diff --git a/assets/2bf90e35.wav b/assets/2bf90e35.wav
new file mode 100644
index 0000000..1230dad
Binary files /dev/null and b/assets/2bf90e35.wav differ
diff --git a/assets/5d67d1b9.wav b/assets/5d67d1b9.wav
new file mode 100644
index 0000000..46beb7e
Binary files /dev/null and b/assets/5d67d1b9.wav differ
diff --git a/assets/README.md b/assets/README.md
index 7c83ea3..1cf23e0 100644
--- a/assets/README.md
+++ b/assets/README.md
@@ -7,21 +7,45 @@ Output:
Input Example : Generate an audio of a piano playing
Output:
![](t2a.png)
+Audio:
+
+
+## Text-To-Speech
+Input Example : Generate a speech with text "here we go"
+Output:
+![](tts.png)
+Audio:
+
+
## Text-To-Sing
Input example : please generate a piece of singing voice. Text sequence is 小酒窝长睫毛AP是你最美的记号. Note sequence is C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4. Note duration sequence is 0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340.
Output:
![](t2s.png)
+Audio:
+
## Image-To-Audio
First upload your image(.png)
Input Example : Generate the audio of this image
Output:
![](i2a-2.png)
-## ASR
+Audio:
+
+
+## Speech Recognition
First upload your audio(.wav)
-Input Example : Generate the text of this audio
+Audio Example :
+
+Input Example : Generate the text of this speech
Output:
![](asr.png)
+## Audio-To-Text
+First upload your audio(.wav)
+Audio Example :
+
+Input Example : Please tell me the text description of this audio.
+Output:
+![](a2i.png)
## Style Transfer Text-To-Speech
First upload your audio(.wav)
Input Example : Speak using the voice of this audio. The text is "here we go".
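The audio-input tasks above (Speech Recognition, Audio-To-Text, Style Transfer Text-To-Speech) all start from an uploaded .wav file. As a minimal sketch, assuming a 16 kHz mono clip is acceptable to the demo (the README does not state the expected format), a local file such as the bundled `assets/5d67d1b9.wav` could be prepared with torchaudio before uploading:

```python
# Illustrative preparation of a .wav for upload; the 16 kHz mono target
# is an assumption, not a requirement stated in this README.
import torchaudio
import torchaudio.functional as AF

waveform, sr = torchaudio.load("assets/5d67d1b9.wav")          # [channels, samples]
waveform = waveform.mean(dim=0, keepdim=True)                   # down-mix to mono
waveform = AF.resample(waveform, orig_freq=sr, new_freq=16000)  # resample to 16 kHz
torchaudio.save("input_16k.wav", waveform, sample_rate=16000)
```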
diff --git a/assets/Track 4.wav b/assets/Track 4.wav new file mode 100644 index 0000000..64a19a1 Binary files /dev/null and b/assets/Track 4.wav differ diff --git a/assets/a-group-of-sheep-are-baaing.wav b/assets/a-group-of-sheep-are-baaing.wav new file mode 100644 index 0000000..2c5300a Binary files /dev/null and b/assets/a-group-of-sheep-are-baaing.wav differ diff --git a/assets/a2i.png b/assets/a2i.png new file mode 100644 index 0000000..cbca66b Binary files /dev/null and b/assets/a2i.png differ diff --git a/assets/b973e878.wav b/assets/b973e878.wav new file mode 100644 index 0000000..8abd02d Binary files /dev/null and b/assets/b973e878.wav differ diff --git a/assets/fd5cf55e.wav b/assets/fd5cf55e.wav new file mode 100644 index 0000000..bbc6c76 Binary files /dev/null and b/assets/fd5cf55e.wav differ diff --git a/assets/tts.png b/assets/tts.png new file mode 100644 index 0000000..871f27a Binary files /dev/null and b/assets/tts.png differ diff --git a/audio_to_text/__init__.py b/audio_to_text/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/audio_to_text/captioning/__init__.py b/audio_to_text/captioning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/audio_to_text/captioning/models/__init__.py b/audio_to_text/captioning/models/__init__.py new file mode 100644 index 0000000..7259d67 --- /dev/null +++ b/audio_to_text/captioning/models/__init__.py @@ -0,0 +1,3 @@ +from .base_model import * +from .transformer_model import * + diff --git a/audio_to_text/captioning/models/base_model.py b/audio_to_text/captioning/models/base_model.py new file mode 100644 index 0000000..cd014e9 --- /dev/null +++ b/audio_to_text/captioning/models/base_model.py @@ -0,0 +1,500 @@ +# -*- coding: utf-8 -*- + +from typing import Dict + +import torch +import torch.nn as nn + +from .utils import mean_with_lens, repeat_tensor + + +class CaptionModel(nn.Module): + """ + Encoder-decoder captioning model. 
+ """ + + pad_idx = 0 + start_idx = 1 + end_idx = 2 + max_length = 20 + + def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.vocab_size = decoder.vocab_size + self.train_forward_keys = ["cap", "cap_len", "ss_ratio"] + self.inference_forward_keys = ["sample_method", "max_length", "temp"] + freeze_encoder = kwargs.get("freeze_encoder", False) + if freeze_encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + self.check_decoder_compatibility() + + def check_decoder_compatibility(self): + compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders] + assert isinstance(self.decoder, self.compatible_decoders), \ + f"{self.decoder.__class__.__name__} is incompatible with " \ + f"{self.__class__.__name__}, please use decoder in {compatible_decoders} " + + @classmethod + def set_index(cls, start_idx, end_idx): + cls.start_idx = start_idx + cls.end_idx = end_idx + + def forward(self, input_dict: Dict): + """ + input_dict: { + (required) + mode: train/inference, + spec, + spec_len, + fc, + attn, + attn_len, + [sample_method: greedy], + [temp: 1.0] (in case of no teacher forcing) + + (optional, mode=train) + cap, + cap_len, + ss_ratio, + + (optional, mode=inference) + sample_method: greedy/beam, + max_length, + temp, + beam_size (optional, sample_method=beam), + n_best (optional, sample_method=beam), + } + """ + # encoder_input_keys = ["spec", "spec_len", "fc", "attn", "attn_len"] + # encoder_input = { key: input_dict[key] for key in encoder_input_keys } + encoder_output_dict = self.encoder(input_dict) + if input_dict["mode"] == "train": + forward_dict = { + "mode": "train", "sample_method": "greedy", "temp": 1.0 + } + for key in self.train_forward_keys: + forward_dict[key] = input_dict[key] + forward_dict.update(encoder_output_dict) + output = self.train_forward(forward_dict) + elif input_dict["mode"] == "inference": + forward_dict = {"mode": "inference"} + default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 } + for key in self.inference_forward_keys: + if key in input_dict: + forward_dict[key] = input_dict[key] + else: + forward_dict[key] = default_args[key] + + if forward_dict["sample_method"] == "beam": + forward_dict["beam_size"] = input_dict.get("beam_size", 3) + forward_dict["n_best"] = input_dict.get("n_best", False) + forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"]) + elif forward_dict["sample_method"] == "dbs": + forward_dict["beam_size"] = input_dict.get("beam_size", 6) + forward_dict["group_size"] = input_dict.get("group_size", 3) + forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5) + forward_dict["group_nbest"] = input_dict.get("group_nbest", True) + + forward_dict.update(encoder_output_dict) + output = self.inference_forward(forward_dict) + else: + raise Exception("mode should be either 'train' or 'inference'") + + return output + + def prepare_output(self, input_dict): + output = {} + batch_size = input_dict["fc_emb"].size(0) + if input_dict["mode"] == "train": + max_length = input_dict["cap"].size(1) - 1 + elif input_dict["mode"] == "inference": + max_length = input_dict["max_length"] + else: + raise Exception("mode should be either 'train' or 'inference'") + device = input_dict["fc_emb"].device + output["seq"] = torch.full((batch_size, max_length), self.end_idx, + dtype=torch.long) + output["logit"] = torch.empty(batch_size, max_length, + 
self.vocab_size).to(device) + output["sampled_logprob"] = torch.zeros(batch_size, max_length) + output["embed"] = torch.empty(batch_size, max_length, + self.decoder.d_model).to(device) + return output + + def train_forward(self, input_dict): + if input_dict["ss_ratio"] != 1: # scheduled sampling training + input_dict["mode"] = "train" + return self.stepwise_forward(input_dict) + output = self.seq_forward(input_dict) + self.train_process(output, input_dict) + return output + + def seq_forward(self, input_dict): + raise NotImplementedError + + def train_process(self, output, input_dict): + pass + + def inference_forward(self, input_dict): + if input_dict["sample_method"] == "beam": + return self.beam_search(input_dict) + elif input_dict["sample_method"] == "dbs": + return self.diverse_beam_search(input_dict) + return self.stepwise_forward(input_dict) + + def stepwise_forward(self, input_dict): + """Step-by-step decoding""" + output = self.prepare_output(input_dict) + max_length = output["seq"].size(1) + # start sampling + for t in range(max_length): + input_dict["t"] = t + self.decode_step(input_dict, output) + if input_dict["mode"] == "inference": # decide whether to stop when sampling + unfinished_t = output["seq"][:, t] != self.end_idx + if t == 0: + unfinished = unfinished_t + else: + unfinished *= unfinished_t + output["seq"][:, t][~unfinished] = self.end_idx + if unfinished.sum() == 0: + break + self.stepwise_process(output) + return output + + def decode_step(self, input_dict, output): + """Decoding operation of timestep t""" + decoder_input = self.prepare_decoder_input(input_dict, output) + # feed to the decoder to get logit + output_t = self.decoder(decoder_input) + logit_t = output_t["logit"] + # assert logit_t.ndim == 3 + if logit_t.size(1) == 1: + logit_t = logit_t.squeeze(1) + embed_t = output_t["embed"].squeeze(1) + elif logit_t.size(1) > 1: + logit_t = logit_t[:, -1, :] + embed_t = output_t["embed"][:, -1, :] + else: + raise Exception("no logit output") + # sample the next input word and get the corresponding logit + sampled = self.sample_next_word(logit_t, + method=input_dict["sample_method"], + temp=input_dict["temp"]) + + output_t.update(sampled) + output_t["t"] = input_dict["t"] + output_t["logit"] = logit_t + output_t["embed"] = embed_t + self.stepwise_process_step(output, output_t) + + def prepare_decoder_input(self, input_dict, output): + """Prepare the inp ut dict for the decoder""" + raise NotImplementedError + + def stepwise_process_step(self, output, output_t): + """Postprocessing (save output values) after each timestep t""" + t = output_t["t"] + output["logit"][:, t, :] = output_t["logit"] + output["seq"][:, t] = output_t["word"] + output["sampled_logprob"][:, t] = output_t["probs"] + output["embed"][:, t, :] = output_t["embed"] + + def stepwise_process(self, output): + """Postprocessing after the whole step-by-step autoregressive decoding""" + pass + + def sample_next_word(self, logit, method, temp): + """Sample the next word, given probs output by the decoder""" + logprob = torch.log_softmax(logit, dim=1) + if method == "greedy": + sampled_logprob, word = torch.max(logprob.detach(), 1) + elif method == "gumbel": + def sample_gumbel(shape, eps=1e-20): + U = torch.rand(shape).to(logprob.device) + return -torch.log(-torch.log(U + eps) + eps) + def gumbel_softmax_sample(logit, temperature): + y = logit + sample_gumbel(logit.size()) + return torch.log_softmax(y / temperature, dim=-1) + _logprob = gumbel_softmax_sample(logprob, temp) + _, word = 
torch.max(_logprob.data, 1) + sampled_logprob = logprob.gather(1, word.unsqueeze(-1)) + else: + logprob = logprob / temp + if method.startswith("top"): + top_num = float(method[3:]) + if 0 < top_num < 1: # top-p sampling + probs = torch.softmax(logit, dim=1) + sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) + _cumsum = sorted_probs.cumsum(1) + mask = _cumsum < top_num + mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1) + sorted_probs = sorted_probs * mask.to(sorted_probs) + sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) + logprob.scatter_(1, sorted_indices, sorted_probs.log()) + else: # top-k sampling + k = int(top_num) + tmp = torch.empty_like(logprob).fill_(float('-inf')) + topk, indices = torch.topk(logprob, k, dim=1) + tmp = tmp.scatter(1, indices, topk) + logprob = tmp + word = torch.distributions.Categorical(logits=logprob.detach()).sample() + sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1) + word = word.detach().long() + # sampled_logprob: [N,], word: [N,] + return {"word": word, "probs": sampled_logprob} + + def beam_search(self, input_dict): + output = self.prepare_output(input_dict) + max_length = input_dict["max_length"] + beam_size = input_dict["beam_size"] + if input_dict["n_best"]: + n_best_size = input_dict["n_best_size"] + batch_size, max_length = output["seq"].size() + output["seq"] = torch.full((batch_size, n_best_size, max_length), + self.end_idx, dtype=torch.long) + + temp = input_dict["temp"] + # instance by instance beam seach + for i in range(output["seq"].size(0)): + output_i = self.prepare_beamsearch_output(input_dict) + input_dict["sample_idx"] = i + for t in range(max_length): + input_dict["t"] = t + output_t = self.beamsearch_step(input_dict, output_i) + ####################################### + # merge with previous beam and select the current max prob beam + ####################################### + logit_t = output_t["logit"] + if logit_t.size(1) == 1: + logit_t = logit_t.squeeze(1) + elif logit_t.size(1) > 1: + logit_t = logit_t[:, -1, :] + else: + raise Exception("no logit output") + logprob_t = torch.log_softmax(logit_t, dim=1) + logprob_t = torch.log_softmax(logprob_t / temp, dim=1) + logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t + if t == 0: # for the first step, all k seq will have the same probs + topk_logprob, topk_words = logprob_t[0].topk( + beam_size, 0, True, True) + else: # unroll and find top logprob, and their unrolled indices + topk_logprob, topk_words = logprob_t.view(-1).topk( + beam_size, 0, True, True) + topk_words = topk_words.cpu() + output_i["topk_logprob"] = topk_logprob + # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,] + output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size, + rounding_mode='trunc') + output_i["next_word"] = topk_words % self.vocab_size # [beam_size,] + if t == 0: + output_i["seq"] = output_i["next_word"].unsqueeze(1) + else: + output_i["seq"] = torch.cat([ + output_i["seq"][output_i["prev_words_beam"]], + output_i["next_word"].unsqueeze(1)], dim=1) + + # add finished beams to results + is_end = output_i["next_word"] == self.end_idx + if t == max_length - 1: + is_end.fill_(1) + + for beam_idx in range(beam_size): + if is_end[beam_idx]: + final_beam = { + "seq": output_i["seq"][beam_idx].clone(), + "score": output_i["topk_logprob"][beam_idx].item() + } + final_beam["score"] = final_beam["score"] / (t + 1) + output_i["done_beams"].append(final_beam) + output_i["topk_logprob"][is_end] -= 
1000 + + self.beamsearch_process_step(output_i, output_t) + + self.beamsearch_process(output, output_i, input_dict) + return output + + def prepare_beamsearch_output(self, input_dict): + beam_size = input_dict["beam_size"] + device = input_dict["fc_emb"].device + output = { + "topk_logprob": torch.zeros(beam_size).to(device), + "seq": None, + "prev_words_beam": None, + "next_word": None, + "done_beams": [], + } + return output + + def beamsearch_step(self, input_dict, output_i): + decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i) + output_t = self.decoder(decoder_input) + output_t["t"] = input_dict["t"] + return output_t + + def prepare_beamsearch_decoder_input(self, input_dict, output_i): + raise NotImplementedError + + def beamsearch_process_step(self, output_i, output_t): + pass + + def beamsearch_process(self, output, output_i, input_dict): + i = input_dict["sample_idx"] + done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"]) + if input_dict["n_best"]: + done_beams = done_beams[:input_dict["n_best_size"]] + for out_idx, done_beam in enumerate(done_beams): + seq = done_beam["seq"] + output["seq"][i][out_idx, :len(seq)] = seq + else: + seq = done_beams[0]["seq"] + output["seq"][i][:len(seq)] = seq + + def diverse_beam_search(self, input_dict): + + def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash): + local_time = t - divm + unaug_logprob = logprob.clone() + + if divm > 0: + change = torch.zeros(logprob.size(-1)) + for prev_choice in range(divm): + prev_decisions = seq_table[prev_choice][..., local_time] + for prev_labels in range(bdash): + change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1)) + + change = change.to(logprob.device) + logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda + + return logprob, unaug_logprob + + output = self.prepare_output(input_dict) + group_size = input_dict["group_size"] + batch_size = output["seq"].size(0) + beam_size = input_dict["beam_size"] + bdash = beam_size // group_size + input_dict["bdash"] = bdash + diversity_lambda = input_dict["diversity_lambda"] + device = input_dict["fc_emb"].device + max_length = input_dict["max_length"] + temp = input_dict["temp"] + group_nbest = input_dict["group_nbest"] + batch_size, max_length = output["seq"].size() + if group_nbest: + output["seq"] = torch.full((batch_size, beam_size, max_length), + self.end_idx, dtype=torch.long) + else: + output["seq"] = torch.full((batch_size, group_size, max_length), + self.end_idx, dtype=torch.long) + + + for i in range(batch_size): + input_dict["sample_idx"] = i + seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0] + logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)] + done_beams_table = [[] for _ in range(group_size)] + + output_i = { + "prev_words_beam": [None for _ in range(group_size)], + "next_word": [None for _ in range(group_size)], + "state": [None for _ in range(group_size)] + } + + for t in range(max_length + group_size - 1): + input_dict["t"] = t + for divm in range(group_size): + input_dict["divm"] = divm + if t >= divm and t <= max_length + divm - 1: + local_time = t - divm + decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i) + output_t = self.decoder(decoder_input) + output_t["divm"] = divm + logit_t = output_t["logit"] + if logit_t.size(1) == 1: + logit_t = logit_t.squeeze(1) + elif logit_t.size(1) > 1: + logit_t = logit_t[:, -1, :] + else: + raise Exception("no logit output") + logprob_t 
= torch.log_softmax(logit_t, dim=1) + logprob_t = torch.log_softmax(logprob_t / temp, dim=1) + logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash) + logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t + if local_time == 0: # for the first step, all k seq will have the same probs + topk_logprob, topk_words = logprob_t[0].topk( + bdash, 0, True, True) + else: # unroll and find top logprob, and their unrolled indices + topk_logprob, topk_words = logprob_t.view(-1).topk( + bdash, 0, True, True) + topk_words = topk_words.cpu() + logprob_table[divm] = topk_logprob + output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,] + output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,] + if local_time > 0: + seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]] + seq_table[divm] = torch.cat([ + seq_table[divm], + output_i["next_word"][divm].unsqueeze(-1)], -1) + + is_end = seq_table[divm][:, t-divm] == self.end_idx + assert seq_table[divm].shape[-1] == t - divm + 1 + if t == max_length + divm - 1: + is_end.fill_(1) + for beam_idx in range(bdash): + if is_end[beam_idx]: + final_beam = { + "seq": seq_table[divm][beam_idx].clone(), + "score": logprob_table[divm][beam_idx].item() + } + final_beam["score"] = final_beam["score"] / (t - divm + 1) + done_beams_table[divm].append(final_beam) + logprob_table[divm][is_end] -= 1000 + self.dbs_process_step(output_i, output_t) + done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)] + if group_nbest: + done_beams = sum(done_beams_table, []) + else: + done_beams = [group_beam[0] for group_beam in done_beams_table] + for _, done_beam in enumerate(done_beams): + output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"] + + return output + + def prepare_dbs_decoder_input(self, input_dict, output_i): + raise NotImplementedError + + def dbs_process_step(self, output_i, output_t): + pass + + +class CaptionSequenceModel(nn.Module): + + def __init__(self, model, seq_output_size): + super().__init__() + self.model = model + if model.decoder.d_model != seq_output_size: + self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size) + else: + self.output_transform = lambda x: x + + def forward(self, input_dict): + output = self.model(input_dict) + + if input_dict["mode"] == "train": + lens = input_dict["cap_len"] - 1 + # seq_outputs: [N, d_model] + elif input_dict["mode"] == "inference": + if "sample_method" in input_dict and input_dict["sample_method"] == "beam": + return output + seq = output["seq"] + lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1) + else: + raise Exception("mode should be either 'train' or 'inference'") + seq_output = mean_with_lens(output["embed"], lens) + seq_output = self.output_transform(seq_output) + output["seq_output"] = seq_output + return output + diff --git a/audio_to_text/captioning/models/decoder.py b/audio_to_text/captioning/models/decoder.py new file mode 100644 index 0000000..869eac1 --- /dev/null +++ b/audio_to_text/captioning/models/decoder.py @@ -0,0 +1,746 @@ +# -*- coding: utf-8 -*- + +import math +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +from .utils import generate_length_mask, init, PositionalEncoding + + +class BaseDecoder(nn.Module): + """ + Take word/audio embeddings and output the next word probs + Base decoder, cannot be called directly + All decoders 
should inherit from this class + """ + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, + attn_emb_dim, dropout=0.2): + super().__init__() + self.emb_dim = emb_dim + self.vocab_size = vocab_size + self.fc_emb_dim = fc_emb_dim + self.attn_emb_dim = attn_emb_dim + self.word_embedding = nn.Embedding(vocab_size, emb_dim) + self.in_dropout = nn.Dropout(dropout) + + def forward(self, x): + raise NotImplementedError + + def load_word_embedding(self, weight, freeze=True): + embedding = np.load(weight) + assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch" + assert embedding.shape[1] == self.emb_dim, "embed size mismatch" + + # embeddings = torch.as_tensor(embeddings).float() + # self.word_embeddings.weight = nn.Parameter(embeddings) + # for para in self.word_embeddings.parameters(): + # para.requires_grad = tune + self.word_embedding = nn.Embedding.from_pretrained(embedding, + freeze=freeze) + + +class RnnDecoder(BaseDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs): + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout,) + self.d_model = d_model + self.num_layers = kwargs.get('num_layers', 1) + self.bidirectional = kwargs.get('bidirectional', False) + self.rnn_type = kwargs.get('rnn_type', "GRU") + self.classifier = nn.Linear( + self.d_model * (self.bidirectional + 1), vocab_size) + + def forward(self, x): + raise NotImplementedError + + def init_hidden(self, bs, device): + num_dire = self.bidirectional + 1 + n_layer = self.num_layers + hid_dim = self.d_model + if self.rnn_type == "LSTM": + return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device), + torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)) + else: + return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device) + + +class RnnFcDecoder(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs): + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim * 2, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim) + self.apply(init) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) + fc_emb = input_dict["fc_emb"] + + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + + p_fc_emb = self.fc_proj(fc_emb) + # embed: [N, T, embed_size] + embed = torch.cat((embed, p_fc_emb), dim=-1) + + out, state = self.model(embed, state) + # out: [N, T, hs], states: [num_layers * num_dire, N, hs] + logits = self.classifier(out) + output = { + "state": state, + "embeds": out, + "logits": logits + } + + return output + + +class Seq2SeqAttention(nn.Module): + + def __init__(self, hs_enc, hs_dec, attn_size): + """ + Args: + hs_enc: encoder hidden size + hs_dec: decoder hidden size + attn_size: attention vector size + """ + super(Seq2SeqAttention, self).__init__() + self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size) + self.v = nn.Parameter(torch.randn(attn_size)) + self.apply(init) + + def forward(self, h_dec, h_enc, src_lens): + """ + Args: + h_dec: decoder hidden (query), [N, hs_dec] + h_enc: encoder memory (key/value), [N, src_max_len, hs_enc] + src_lens: source (encoder memory) lengths, [N, ] + """ + N = h_enc.size(0) + src_max_len = h_enc.size(1) + h_dec = h_dec.unsqueeze(1).repeat(1, 
src_max_len, 1) # [N, src_max_len, hs_dec] + + attn_input = torch.cat((h_dec, h_enc), dim=-1) + attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size] + + v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size] + score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len] + + idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len) + mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device) + + score = score.masked_fill(mask == 0, -1e10) + weights = torch.softmax(score, dim=-1) # [N, src_max_len] + ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc] + + return ctx, weights + + +class AttentionProj(nn.Module): + + def __init__(self, hs_enc, hs_dec, embed_dim, attn_size): + self.q_proj = nn.Linear(hs_dec, embed_dim) + self.kv_proj = nn.Linear(hs_enc, embed_dim) + self.h2attn = nn.Linear(embed_dim * 2, attn_size) + self.v = nn.Parameter(torch.randn(attn_size)) + self.apply(init) + + def init(self, m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, h_dec, h_enc, src_lens): + """ + Args: + h_dec: decoder hidden (query), [N, hs_dec] + h_enc: encoder memory (key/value), [N, src_max_len, hs_enc] + src_lens: source (encoder memory) lengths, [N, ] + """ + h_enc = self.kv_proj(h_enc) # [N, src_max_len, embed_dim] + h_dec = self.q_proj(h_dec) # [N, embed_dim] + N = h_enc.size(0) + src_max_len = h_enc.size(1) + h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec] + + attn_input = torch.cat((h_dec, h_enc), dim=-1) + attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size] + + v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size] + score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len] + + idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len) + mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device) + + score = score.masked_fill(mask == 0, -1e10) + weights = torch.softmax(score, dim=-1) # [N, src_max_len] + ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc] + + return ctx, weights + + +class BahAttnDecoder(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs): + """ + concatenate fc, attn, word to feed to the rnn + """ + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs) + attn_size = kwargs.get("attn_size", self.d_model) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim * 3, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.attn = Seq2SeqAttention(self.attn_emb_dim, + self.d_model * (self.bidirectional + 1) * \ + self.num_layers, + attn_size) + self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim) + self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim) + self.apply(init) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] + fc_emb = input_dict["fc_emb"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + + # embed: [N, 1, embed_size] + if state is None: + state = self.init_hidden(word.size(0), fc_emb.device) + if self.rnn_type == "LSTM": + query = state[0].transpose(0, 1).flatten(1) + else: + query = state.transpose(0, 1).flatten(1) + c, attn_weight = 
self.attn(query, attn_emb, attn_emb_len) + + p_fc_emb = self.fc_proj(fc_emb) + p_ctx = self.ctx_proj(c) + rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)), + dim=-1) + + out, state = self.model(rnn_input, state) + + output = { + "state": state, + "embed": out, + "logit": self.classifier(out), + "attn_weight": attn_weight + } + return output + + +class BahAttnDecoder2(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs): + """ + add fc, attn, word together to feed to the rnn + """ + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs) + attn_size = kwargs.get("attn_size", self.d_model) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.attn = Seq2SeqAttention(self.emb_dim, + self.d_model * (self.bidirectional + 1) * \ + self.num_layers, + attn_size) + self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim) + self.attn_proj = nn.Linear(self.attn_emb_dim, self.emb_dim) + self.apply(partial(init, method="xavier")) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] + fc_emb = input_dict["fc_emb"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + p_attn_emb = self.attn_proj(attn_emb) + + # embed: [N, 1, embed_size] + if state is None: + state = self.init_hidden(word.size(0), fc_emb.device) + if self.rnn_type == "LSTM": + query = state[0].transpose(0, 1).flatten(1) + else: + query = state.transpose(0, 1).flatten(1) + c, attn_weight = self.attn(query, p_attn_emb, attn_emb_len) + + p_fc_emb = self.fc_proj(fc_emb) + rnn_input = embed + c.unsqueeze(1) + p_fc_emb.unsqueeze(1) + + out, state = self.model(rnn_input, state) + + output = { + "state": state, + "embed": out, + "logit": self.classifier(out), + "attn_weight": attn_weight + } + return output + + +class ConditionalBahAttnDecoder(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs): + """ + concatenate fc, attn, word to feed to the rnn + """ + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs) + attn_size = kwargs.get("attn_size", self.d_model) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim * 3, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.attn = Seq2SeqAttention(self.attn_emb_dim, + self.d_model * (self.bidirectional + 1) * \ + self.num_layers, + attn_size) + self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim) + self.condition_embedding = nn.Embedding(2, emb_dim) + self.apply(init) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] + fc_emb = input_dict["fc_emb"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + condition = input_dict["condition"] + + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + + condition = torch.as_tensor([[1 - c, c] for c in condition]).to(fc_emb.device) + condition_emb = torch.matmul(condition, self.condition_embedding.weight) + # condition_embs: [N, emb_dim] + + # embed: [N, 1, embed_size] + if state 
is None: + state = self.init_hidden(word.size(0), fc_emb.device) + if self.rnn_type == "LSTM": + query = state[0].transpose(0, 1).flatten(1) + else: + query = state.transpose(0, 1).flatten(1) + c, attn_weight = self.attn(query, attn_emb, attn_emb_len) + + p_ctx = self.ctx_proj(c) + rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), condition_emb.unsqueeze(1)), + dim=-1) + + out, state = self.model(rnn_input, state) + + output = { + "state": state, + "embed": out, + "logit": self.classifier(out), + "attn_weight": attn_weight + } + return output + + +class StructBahAttnDecoder(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, struct_vocab_size, + attn_emb_dim, dropout, d_model, **kwargs): + """ + concatenate fc, attn, word to feed to the rnn + """ + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs) + attn_size = kwargs.get("attn_size", self.d_model) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim * 3, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.attn = Seq2SeqAttention(self.attn_emb_dim, + self.d_model * (self.bidirectional + 1) * \ + self.num_layers, + attn_size) + self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim) + self.struct_embedding = nn.Embedding(struct_vocab_size, emb_dim) + self.apply(init) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] + fc_emb = input_dict["fc_emb"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + structure = input_dict["structure"] + + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + + struct_emb = self.struct_embedding(structure) + # struct_embs: [N, emb_dim] + + # embed: [N, 1, embed_size] + if state is None: + state = self.init_hidden(word.size(0), fc_emb.device) + if self.rnn_type == "LSTM": + query = state[0].transpose(0, 1).flatten(1) + else: + query = state.transpose(0, 1).flatten(1) + c, attn_weight = self.attn(query, attn_emb, attn_emb_len) + + p_ctx = self.ctx_proj(c) + rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), struct_emb.unsqueeze(1)), dim=-1) + + out, state = self.model(rnn_input, state) + + output = { + "state": state, + "embed": out, + "logit": self.classifier(out), + "attn_weight": attn_weight + } + return output + + +class StyleBahAttnDecoder(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs): + """ + concatenate fc, attn, word to feed to the rnn + """ + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs) + attn_size = kwargs.get("attn_size", self.d_model) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim * 3, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.attn = Seq2SeqAttention(self.attn_emb_dim, + self.d_model * (self.bidirectional + 1) * \ + self.num_layers, + attn_size) + self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim) + self.apply(init) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] + fc_emb = input_dict["fc_emb"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + style = input_dict["style"] + + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + + # embed: 
[N, 1, embed_size] + if state is None: + state = self.init_hidden(word.size(0), fc_emb.device) + if self.rnn_type == "LSTM": + query = state[0].transpose(0, 1).flatten(1) + else: + query = state.transpose(0, 1).flatten(1) + c, attn_weight = self.attn(query, attn_emb, attn_emb_len) + + p_ctx = self.ctx_proj(c) + rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), style.unsqueeze(1)), + dim=-1) + + out, state = self.model(rnn_input, state) + + output = { + "state": state, + "embed": out, + "logit": self.classifier(out), + "attn_weight": attn_weight + } + return output + + +class BahAttnDecoder3(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs): + """ + concatenate fc, attn, word to feed to the rnn + """ + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs) + attn_size = kwargs.get("attn_size", self.d_model) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim + attn_emb_dim, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.attn = Seq2SeqAttention(self.attn_emb_dim, + self.d_model * (self.bidirectional + 1) * \ + self.num_layers, + attn_size) + self.ctx_proj = lambda x: x + self.apply(init) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] + fc_emb = input_dict["fc_emb"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + + if word.size(-1) == self.fc_emb_dim: # fc_emb + embed = word.unsqueeze(1) + elif word.size(-1) == 1: # word + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + else: + raise Exception(f"problem with word input size {word.size()}") + + # embed: [N, 1, embed_size] + if state is None: + state = self.init_hidden(word.size(0), fc_emb.device) + if self.rnn_type == "LSTM": + query = state[0].transpose(0, 1).flatten(1) + else: + query = state.transpose(0, 1).flatten(1) + c, attn_weight = self.attn(query, attn_emb, attn_emb_len) + + p_ctx = self.ctx_proj(c) + rnn_input = torch.cat((embed, p_ctx.unsqueeze(1)), dim=-1) + + out, state = self.model(rnn_input, state) + + output = { + "state": state, + "embed": out, + "logit": self.classifier(out), + "attn_weight": attn_weight + } + return output + + +class SpecificityBahAttnDecoder(RnnDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs): + """ + concatenate fc, attn, word to feed to the rnn + """ + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, d_model, **kwargs) + attn_size = kwargs.get("attn_size", self.d_model) + self.model = getattr(nn, self.rnn_type)( + input_size=self.emb_dim + attn_emb_dim + 1, + hidden_size=self.d_model, + batch_first=True, + num_layers=self.num_layers, + bidirectional=self.bidirectional) + self.attn = Seq2SeqAttention(self.attn_emb_dim, + self.d_model * (self.bidirectional + 1) * \ + self.num_layers, + attn_size) + self.ctx_proj = lambda x: x + self.apply(init) + + def forward(self, input_dict): + word = input_dict["word"] + state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] + fc_emb = input_dict["fc_emb"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + condition = input_dict["condition"] # [N,] + + word = word.to(fc_emb.device) + embed = self.in_dropout(self.word_embedding(word)) + + # embed: [N, 1, embed_size] + if state is None: + 
state = self.init_hidden(word.size(0), fc_emb.device) + if self.rnn_type == "LSTM": + query = state[0].transpose(0, 1).flatten(1) + else: + query = state.transpose(0, 1).flatten(1) + c, attn_weight = self.attn(query, attn_emb, attn_emb_len) + + p_ctx = self.ctx_proj(c) + rnn_input = torch.cat( + (embed, p_ctx.unsqueeze(1), condition.reshape(-1, 1, 1)), + dim=-1) + + out, state = self.model(rnn_input, state) + + output = { + "state": state, + "embed": out, + "logit": self.classifier(out), + "attn_weight": attn_weight + } + return output + + +class TransformerDecoder(BaseDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, **kwargs): + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout=dropout,) + self.d_model = emb_dim + self.nhead = kwargs.get("nhead", self.d_model // 64) + self.nlayers = kwargs.get("nlayers", 2) + self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4) + + self.pos_encoder = PositionalEncoding(self.d_model, dropout) + layer = nn.TransformerDecoderLayer(d_model=self.d_model, + nhead=self.nhead, + dim_feedforward=self.dim_feedforward, + dropout=dropout) + self.model = nn.TransformerDecoder(layer, self.nlayers) + self.classifier = nn.Linear(self.d_model, vocab_size) + self.attn_proj = nn.Sequential( + nn.Linear(self.attn_emb_dim, self.d_model), + nn.ReLU(), + nn.Dropout(dropout), + nn.LayerNorm(self.d_model) + ) + # self.attn_proj = lambda x: x + self.init_params() + + def init_params(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def generate_square_subsequent_mask(self, max_length): + mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def forward(self, input_dict): + word = input_dict["word"] + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + cap_padding_mask = input_dict["cap_padding_mask"] + + p_attn_emb = self.attn_proj(attn_emb) + p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim] + word = word.to(attn_emb.device) + embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim] + embed = embed.transpose(0, 1) # [T, N, emb_dim] + embed = self.pos_encoder(embed) + + tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device) + memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device) + output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask, + tgt_key_padding_mask=cap_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + output = output.transpose(0, 1) + output = { + "embed": output, + "logit": self.classifier(output), + } + return output + + + + +class EventTransformerDecoder(TransformerDecoder): + + def forward(self, input_dict): + word = input_dict["word"] # index of word embeddings + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + cap_padding_mask = input_dict["cap_padding_mask"] + event_emb = input_dict["event"] # [N, emb_dim] + + p_attn_emb = self.attn_proj(attn_emb) + p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim] + word = word.to(attn_emb.device) + embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim] + + embed = embed.transpose(0, 1) # [T, N, emb_dim] + embed += event_emb + embed = self.pos_encoder(embed) + + tgt_mask = 
self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device) + memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device) + output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask, + tgt_key_padding_mask=cap_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + output = output.transpose(0, 1) + output = { + "embed": output, + "logit": self.classifier(output), + } + return output + + +class KeywordProbTransformerDecoder(TransformerDecoder): + + def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, keyword_classes_num, **kwargs): + super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, + dropout, **kwargs) + self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model) + self.word_keyword_norm = nn.LayerNorm(self.d_model) + + def forward(self, input_dict): + word = input_dict["word"] # index of word embeddings + attn_emb = input_dict["attn_emb"] + attn_emb_len = input_dict["attn_emb_len"] + cap_padding_mask = input_dict["cap_padding_mask"] + keyword = input_dict["keyword"] # [N, keyword_classes_num] + + p_attn_emb = self.attn_proj(attn_emb) + p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim] + word = word.to(attn_emb.device) + embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim] + + embed = embed.transpose(0, 1) # [T, N, emb_dim] + embed += self.keyword_proj(keyword) + embed = self.word_keyword_norm(embed) + + embed = self.pos_encoder(embed) + + tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device) + memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device) + output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask, + tgt_key_padding_mask=cap_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + output = output.transpose(0, 1) + output = { + "embed": output, + "logit": self.classifier(output), + } + return output diff --git a/audio_to_text/captioning/models/encoder.py b/audio_to_text/captioning/models/encoder.py new file mode 100644 index 0000000..0d6d8e8 --- /dev/null +++ b/audio_to_text/captioning/models/encoder.py @@ -0,0 +1,686 @@ +# -*- coding: utf-8 -*- + +import math +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchaudio import transforms +from torchlibrosa.augmentation import SpecAugmentation + +from .utils import mean_with_lens, max_with_lens, \ + init, pack_wrapper, generate_length_mask, PositionalEncoding + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer. """ + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, 'bias'): + if layer.bias is not None: + layer.bias.data.fill_(0.) + + +def init_bn(bn): + """Initialize a Batchnorm layer. """ + bn.bias.data.fill_(0.) + bn.weight.data.fill_(1.) 
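The transformer decoders above rely on `generate_square_subsequent_mask` to enforce causal (left-to-right) attention over the caption tokens. A small standalone check, copying the mask construction verbatim from `TransformerDecoder`, shows the additive mask it produces: allowed positions get 0, future positions get -inf.

```python
# Standalone check of the causal mask built by
# TransformerDecoder.generate_square_subsequent_mask (copied from above).
import torch

def generate_square_subsequent_mask(max_length):
    mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
    return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)

print(generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```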
+ + +class BaseEncoder(nn.Module): + + """ + Encode the given audio into embedding + Base encoder class, cannot be called directly + All encoders should inherit from this class + """ + + def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim): + super(BaseEncoder, self).__init__() + self.spec_dim = spec_dim + self.fc_feat_dim = fc_feat_dim + self.attn_feat_dim = attn_feat_dim + + + def forward(self, x): + ######################### + # an encoder first encodes audio feature into embedding, obtaining + # `encoded`: { + # fc_embs: [N, fc_emb_dim], + # attn_embs: [N, attn_max_len, attn_emb_dim], + # attn_emb_lens: [N,] + # } + ######################### + raise NotImplementedError + + +class Block2D(nn.Module): + + def __init__(self, cin, cout, kernel_size=3, padding=1): + super().__init__() + self.block = nn.Sequential( + nn.BatchNorm2d(cin), + nn.Conv2d(cin, + cout, + kernel_size=kernel_size, + padding=padding, + bias=False), + nn.LeakyReLU(inplace=True, negative_slope=0.1)) + + def forward(self, x): + return self.block(x) + + +class LinearSoftPool(nn.Module): + """LinearSoftPool + Linear softmax, takes logits and returns a probability, near to the actual maximum value. + Taken from the paper: + A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling + https://arxiv.org/abs/1810.09050 + """ + def __init__(self, pooldim=1): + super().__init__() + self.pooldim = pooldim + + def forward(self, logits, time_decision): + return (time_decision**2).sum(self.pooldim) / time_decision.sum( + self.pooldim) + + +class MeanPool(nn.Module): + + def __init__(self, pooldim=1): + super().__init__() + self.pooldim = pooldim + + def forward(self, logits, decision): + return torch.mean(decision, dim=self.pooldim) + + +class AttentionPool(nn.Module): + """docstring for AttentionPool""" + def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs): + super().__init__() + self.inputdim = inputdim + self.outputdim = outputdim + self.pooldim = pooldim + self.transform = nn.Linear(inputdim, outputdim) + self.activ = nn.Softmax(dim=self.pooldim) + self.eps = 1e-7 + + def forward(self, logits, decision): + # Input is (B, T, D) + # B, T, D + w = self.activ(torch.clamp(self.transform(logits), -15, 15)) + detect = (decision * w).sum( + self.pooldim) / (w.sum(self.pooldim) + self.eps) + # B, T, D + return detect + + +class MMPool(nn.Module): + + def __init__(self, dims): + super().__init__() + self.avgpool = nn.AvgPool2d(dims) + self.maxpool = nn.MaxPool2d(dims) + + def forward(self, x): + return self.avgpool(x) + self.maxpool(x) + + +def parse_poolingfunction(poolingfunction_name='mean', **kwargs): + """parse_poolingfunction + A heler function to parse any temporal pooling + Pooling is done on dimension 1 + :param poolingfunction_name: + :param **kwargs: + """ + poolingfunction_name = poolingfunction_name.lower() + if poolingfunction_name == 'mean': + return MeanPool(pooldim=1) + elif poolingfunction_name == 'linear': + return LinearSoftPool(pooldim=1) + elif poolingfunction_name == 'attention': + return AttentionPool(inputdim=kwargs['inputdim'], + outputdim=kwargs['outputdim']) + + +def embedding_pooling(x, lens, pooling="mean"): + if pooling == "max": + fc_embs = max_with_lens(x, lens) + elif pooling == "mean": + fc_embs = mean_with_lens(x, lens) + elif pooling == "mean+max": + x_mean = mean_with_lens(x, lens) + x_max = max_with_lens(x, lens) + fc_embs = x_mean + x_max + elif pooling == "last": + indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1)) + # 
indices: [N, 1, hidden] + fc_embs = torch.gather(x, 1, indices).squeeze(1) + else: + raise Exception(f"pooling method {pooling} not support") + return fc_embs + + +class Cdur5Encoder(BaseEncoder): + + def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"): + super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) + self.pooling = pooling + self.features = nn.Sequential( + Block2D(1, 32), + nn.LPPool2d(4, (2, 4)), + Block2D(32, 128), + Block2D(128, 128), + nn.LPPool2d(4, (2, 4)), + Block2D(128, 128), + Block2D(128, 128), + nn.LPPool2d(4, (1, 4)), + nn.Dropout(0.3), + ) + with torch.no_grad(): + rnn_input_dim = self.features( + torch.randn(1, 1, 500, spec_dim)).shape + rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1] + + self.gru = nn.GRU(rnn_input_dim, + 128, + bidirectional=True, + batch_first=True) + self.apply(init) + + def forward(self, input_dict): + x = input_dict["spec"] + lens = input_dict["spec_len"] + if "upsample" not in input_dict: + input_dict["upsample"] = False + lens = torch.as_tensor(copy.deepcopy(lens)) + N, T, _ = x.shape + x = x.unsqueeze(1) + x = self.features(x) + x = x.transpose(1, 2).contiguous().flatten(-2) + x, _ = self.gru(x) + if input_dict["upsample"]: + x = nn.functional.interpolate( + x.transpose(1, 2), + T, + mode='linear', + align_corners=False).transpose(1, 2) + else: + lens //= 4 + attn_emb = x + fc_emb = embedding_pooling(x, lens, self.pooling) + return { + "attn_emb": attn_emb, + "fc_emb": fc_emb, + "attn_emb_len": lens + } + + +def conv_conv_block(in_channel, out_channel): + return nn.Sequential( + nn.Conv2d(in_channel, + out_channel, + kernel_size=3, + bias=False, + padding=1), + nn.BatchNorm2d(out_channel), + nn.ReLU(True), + nn.Conv2d(out_channel, + out_channel, + kernel_size=3, + bias=False, + padding=1), + nn.BatchNorm2d(out_channel), + nn.ReLU(True) + ) + + +class Cdur8Encoder(BaseEncoder): + + def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"): + super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) + self.pooling = pooling + self.features = nn.Sequential( + conv_conv_block(1, 64), + MMPool((2, 2)), + nn.Dropout(0.2, True), + conv_conv_block(64, 128), + MMPool((2, 2)), + nn.Dropout(0.2, True), + conv_conv_block(128, 256), + MMPool((1, 2)), + nn.Dropout(0.2, True), + conv_conv_block(256, 512), + MMPool((1, 2)), + nn.Dropout(0.2, True), + nn.AdaptiveAvgPool2d((None, 1)), + ) + self.init_bn = nn.BatchNorm2d(spec_dim) + self.embedding = nn.Linear(512, 512) + self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True) + self.apply(init) + + def forward(self, input_dict): + x = input_dict["spec"] + lens = input_dict["spec_len"] + lens = torch.as_tensor(copy.deepcopy(lens)) + x = x.unsqueeze(1) # B x 1 x T x D + x = x.transpose(1, 3) + x = self.init_bn(x) + x = x.transpose(1, 3) + x = self.features(x) + x = x.transpose(1, 2).contiguous().flatten(-2) + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.embedding(x)) + x, _ = self.gru(x) + attn_emb = x + lens //= 4 + fc_emb = embedding_pooling(x, lens, self.pooling) + return { + "attn_emb": attn_emb, + "fc_emb": fc_emb, + "attn_emb_len": lens + } + + +class Cnn10Encoder(BaseEncoder): + + def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim): + super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) + self.features = nn.Sequential( + conv_conv_block(1, 64), + nn.AvgPool2d((2, 2)), + nn.Dropout(0.2, True), + conv_conv_block(64, 128), + nn.AvgPool2d((2, 2)), + nn.Dropout(0.2, True), + conv_conv_block(128, 256), + nn.AvgPool2d((2, 2)), + 
nn.Dropout(0.2, True), + conv_conv_block(256, 512), + nn.AvgPool2d((2, 2)), + nn.Dropout(0.2, True), + nn.AdaptiveAvgPool2d((None, 1)), + ) + self.init_bn = nn.BatchNorm2d(spec_dim) + self.embedding = nn.Linear(512, 512) + self.apply(init) + + def forward(self, input_dict): + x = input_dict["spec"] + lens = input_dict["spec_len"] + lens = torch.as_tensor(copy.deepcopy(lens)) + x = x.unsqueeze(1) # [N, 1, T, D] + x = x.transpose(1, 3) + x = self.init_bn(x) + x = x.transpose(1, 3) + x = self.features(x) # [N, 512, T/16, 1] + x = x.transpose(1, 2).contiguous().flatten(-2) # [N, T/16, 512] + attn_emb = x + lens //= 16 + fc_emb = embedding_pooling(x, lens, "mean+max") + fc_emb = F.dropout(fc_emb, p=0.5, training=self.training) + fc_emb = self.embedding(fc_emb) + fc_emb = F.relu_(fc_emb) + return { + "attn_emb": attn_emb, + "fc_emb": fc_emb, + "attn_emb_len": lens + } + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), stride=(1, 1), + padding=(1, 1), bias=False) + + self.conv2 = nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), stride=(1, 1), + padding=(1, 1), bias=False) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + + def forward(self, input, pool_size=(2, 2), pool_type='avg'): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == 'max': + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg': + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg+max': + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception('Incorrect argument!') + + return x + + +class Cnn14Encoder(nn.Module): + def __init__(self, sample_rate=32000): + super().__init__() + sr_to_fmax = { + 32000: 14000, + 16000: 8000 + } + # Logmel spectrogram extractor + self.melspec_extractor = transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=32 * sample_rate // 1000, + win_length=32 * sample_rate // 1000, + hop_length=10 * sample_rate // 1000, + f_min=50, + f_max=sr_to_fmax[sample_rate], + n_mels=64, + norm="slaney", + mel_scale="slaney" + ) + self.hop_length = 10 * sample_rate // 1000 + self.db_transform = transforms.AmplitudeToDB() + # Spec augmenter + self.spec_augmenter = SpecAugmentation(time_drop_width=64, + time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.downsample_ratio = 32 + + self.fc1 = nn.Linear(2048, 2048, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + + def load_pretrained(self, pretrained): + checkpoint = torch.load(pretrained, map_location="cpu") + + if "model" in checkpoint: + state_keys = checkpoint["model"].keys() + backbone = False + for key in state_keys: + if 
key.startswith("backbone."): + backbone = True + break + + if backbone: # COLA + state_dict = {} + for key, value in checkpoint["model"].items(): + if key.startswith("backbone."): + model_key = key.replace("backbone.", "") + state_dict[model_key] = value + else: # PANNs + state_dict = checkpoint["model"] + elif "state_dict" in checkpoint: # CLAP + state_dict = checkpoint["state_dict"] + state_dict_keys = list(filter( + lambda x: "audio_encoder" in x, state_dict.keys())) + state_dict = { + key.replace('audio_encoder.', ''): state_dict[key] + for key in state_dict_keys + } + else: + raise Exception("Unkown checkpoint format") + + model_dict = self.state_dict() + pretrained_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and ( + model_dict[k].shape == v.shape) + } + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict, strict=True) + + def forward(self, input_dict): + """ + Input: (batch_size, n_samples)""" + waveform = input_dict["wav"] + wave_length = input_dict["wav_len"] + specaug = input_dict["specaug"] + x = self.melspec_extractor(waveform) + x = self.db_transform(x) # (batch_size, mel_bins, time_steps) + x = x.transpose(1, 2) + x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins) + + # SpecAugment + if self.training and specaug: + x = self.spec_augmenter(x) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + attn_emb = x.transpose(1, 2) + + wave_length = torch.as_tensor(wave_length) + feat_length = torch.div(wave_length, self.hop_length, + rounding_mode="floor") + 1 + feat_length = torch.div(feat_length, self.downsample_ratio, + rounding_mode="floor") + x_max = max_with_lens(attn_emb, feat_length) + x_mean = mean_with_lens(attn_emb, feat_length) + x = x_max + x_mean + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + fc_emb = F.dropout(x, p=0.5, training=self.training) + + output_dict = { + 'fc_emb': fc_emb, + 'attn_emb': attn_emb, + 'attn_emb_len': feat_length + } + + return output_dict + + +class RnnEncoder(BaseEncoder): + + def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, + pooling="mean", **kwargs): + super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) + self.pooling = pooling + self.hidden_size = kwargs.get('hidden_size', 512) + self.bidirectional = kwargs.get('bidirectional', False) + self.num_layers = kwargs.get('num_layers', 1) + self.dropout = kwargs.get('dropout', 0.2) + self.rnn_type = kwargs.get('rnn_type', "GRU") + self.in_bn = kwargs.get('in_bn', False) + self.embed_dim = self.hidden_size * (self.bidirectional + 1) + self.network = getattr(nn, self.rnn_type)( + attn_feat_dim, + self.hidden_size, + num_layers=self.num_layers, + bidirectional=self.bidirectional, + dropout=self.dropout, + batch_first=True) + if self.in_bn: + self.bn = nn.BatchNorm1d(self.embed_dim) + self.apply(init) + + def 
forward(self, input_dict): + x = input_dict["attn"] + lens = input_dict["attn_len"] + lens = torch.as_tensor(lens) + # x: [N, T, E] + if self.in_bn: + x = pack_wrapper(self.bn, x, lens) + out = pack_wrapper(self.network, x, lens) + # out: [N, T, hidden] + attn_emb = out + fc_emb = embedding_pooling(out, lens, self.pooling) + return { + "attn_emb": attn_emb, + "fc_emb": fc_emb, + "attn_emb_len": lens + } + + +class Cnn14RnnEncoder(nn.Module): + def __init__(self, sample_rate=32000, pretrained=None, + freeze_cnn=False, freeze_cnn_bn=False, + pooling="mean", **kwargs): + super().__init__() + self.cnn = Cnn14Encoder(sample_rate) + self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs) + if pretrained is not None: + self.cnn.load_pretrained(pretrained) + if freeze_cnn: + assert pretrained is not None, "cnn is not pretrained but frozen" + for param in self.cnn.parameters(): + param.requires_grad = False + self.freeze_cnn_bn = freeze_cnn_bn + + def train(self, mode): + super().train(mode=mode) + if self.freeze_cnn_bn: + def bn_eval(module): + class_name = module.__class__.__name__ + if class_name.find("BatchNorm") != -1: + module.eval() + self.cnn.apply(bn_eval) + return self + + def forward(self, input_dict): + output_dict = self.cnn(input_dict) + output_dict["attn"] = output_dict["attn_emb"] + output_dict["attn_len"] = output_dict["attn_emb_len"] + del output_dict["attn_emb"], output_dict["attn_emb_len"] + output_dict = self.rnn(output_dict) + return output_dict + + +class TransformerEncoder(BaseEncoder): + + def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs): + super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) + self.d_model = d_model + dropout = kwargs.get("dropout", 0.2) + self.nhead = kwargs.get("nhead", self.d_model // 64) + self.nlayers = kwargs.get("nlayers", 2) + self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4) + + self.attn_proj = nn.Sequential( + nn.Linear(attn_feat_dim, self.d_model), + nn.ReLU(), + nn.Dropout(dropout), + nn.LayerNorm(self.d_model) + ) + layer = nn.TransformerEncoderLayer(d_model=self.d_model, + nhead=self.nhead, + dim_feedforward=self.dim_feedforward, + dropout=dropout) + self.model = nn.TransformerEncoder(layer, self.nlayers) + self.cls_token = nn.Parameter(torch.zeros(d_model)) + self.init_params() + + def init_params(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, input_dict): + attn_feat = input_dict["attn"] + attn_feat_len = input_dict["attn_len"] + attn_feat_len = torch.as_tensor(attn_feat_len) + + attn_feat = self.attn_proj(attn_feat) # [bs, T, d_model] + + cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat( + attn_feat.size(0), 1, 1) + attn_feat = torch.cat((cls_emb, attn_feat), dim=1) + attn_feat = attn_feat.transpose(0, 1) + + attn_feat_len += 1 + src_key_padding_mask = ~generate_length_mask( + attn_feat_len, attn_feat.size(0)).to(attn_feat.device) + output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask) + + attn_emb = output.transpose(0, 1) + fc_emb = attn_emb[:, 0] + return { + "attn_emb": attn_emb, + "fc_emb": fc_emb, + "attn_emb_len": attn_feat_len + } + + +class Cnn14TransformerEncoder(nn.Module): + def __init__(self, sample_rate=32000, pretrained=None, + freeze_cnn=False, freeze_cnn_bn=False, + d_model="mean", **kwargs): + super().__init__() + self.cnn = Cnn14Encoder(sample_rate) + self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs) + if pretrained is not None: + self.cnn.load_pretrained(pretrained) 
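+        # Optionally freeze the pretrained CNN so that only the transformer encoder is trained;
+        # BatchNorm running statistics can additionally be kept fixed via freeze_cnn_bn (see train() below).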
+ if freeze_cnn: + assert pretrained is not None, "cnn is not pretrained but frozen" + for param in self.cnn.parameters(): + param.requires_grad = False + self.freeze_cnn_bn = freeze_cnn_bn + + def train(self, mode): + super().train(mode=mode) + if self.freeze_cnn_bn: + def bn_eval(module): + class_name = module.__class__.__name__ + if class_name.find("BatchNorm") != -1: + module.eval() + self.cnn.apply(bn_eval) + return self + + def forward(self, input_dict): + output_dict = self.cnn(input_dict) + output_dict["attn"] = output_dict["attn_emb"] + output_dict["attn_len"] = output_dict["attn_emb_len"] + del output_dict["attn_emb"], output_dict["attn_emb_len"] + output_dict = self.trm(output_dict) + return output_dict + + + + + diff --git a/audio_to_text/captioning/models/transformer_model.py b/audio_to_text/captioning/models/transformer_model.py new file mode 100644 index 0000000..76c97f1 --- /dev/null +++ b/audio_to_text/captioning/models/transformer_model.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +import random +import torch +import torch.nn as nn + +from .base_model import CaptionModel +from .utils import repeat_tensor +import audio_to_text.captioning.models.decoder + + +class TransformerModel(CaptionModel): + + def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs): + if not hasattr(self, "compatible_decoders"): + self.compatible_decoders = ( + audio_to_text.captioning.models.decoder.TransformerDecoder, + ) + super().__init__(encoder, decoder, **kwargs) + + def seq_forward(self, input_dict): + cap = input_dict["cap"] + cap_padding_mask = (cap == self.pad_idx).to(cap.device) + cap_padding_mask = cap_padding_mask[:, :-1] + output = self.decoder( + { + "word": cap[:, :-1], + "attn_emb": input_dict["attn_emb"], + "attn_emb_len": input_dict["attn_emb_len"], + "cap_padding_mask": cap_padding_mask + } + ) + return output + + def prepare_decoder_input(self, input_dict, output): + decoder_input = { + "attn_emb": input_dict["attn_emb"], + "attn_emb_len": input_dict["attn_emb_len"] + } + t = input_dict["t"] + + ############### + # determine input word + ################ + if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling + word = input_dict["cap"][:, :t+1] + else: + start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long() + if t == 0: + word = start_word + else: + word = torch.cat((start_word, output["seq"][:, :t]), dim=-1) + # word: [N, T] + decoder_input["word"] = word + + cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device) + decoder_input["cap_padding_mask"] = cap_padding_mask + return decoder_input + + def prepare_beamsearch_decoder_input(self, input_dict, output_i): + decoder_input = {} + t = input_dict["t"] + i = input_dict["sample_idx"] + beam_size = input_dict["beam_size"] + ############### + # prepare attn embeds + ################ + if t == 0: + attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size) + attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size) + output_i["attn_emb"] = attn_emb + output_i["attn_emb_len"] = attn_emb_len + decoder_input["attn_emb"] = output_i["attn_emb"] + decoder_input["attn_emb_len"] = output_i["attn_emb_len"] + ############### + # determine input word + ################ + start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long() + if t == 0: + word = start_word + else: + word = torch.cat((start_word, output_i["seq"]), dim=-1) + decoder_input["word"] = word + cap_padding_mask = 
(word == self.pad_idx).to(input_dict["attn_emb"].device) + decoder_input["cap_padding_mask"] = cap_padding_mask + + return decoder_input + + +class M2TransformerModel(CaptionModel): + + def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs): + if not hasattr(self, "compatible_decoders"): + self.compatible_decoders = ( + captioning.models.decoder.M2TransformerDecoder, + ) + super().__init__(encoder, decoder, **kwargs) + self.check_encoder_compatibility() + + def check_encoder_compatibility(self): + assert isinstance(self.encoder, captioning.models.encoder.M2TransformerEncoder), \ + f"only M2TransformerModel is compatible with {self.__class__.__name__}" + + + def seq_forward(self, input_dict): + cap = input_dict["cap"] + output = self.decoder( + { + "word": cap[:, :-1], + "attn_emb": input_dict["attn_emb"], + "attn_emb_mask": input_dict["attn_emb_mask"], + } + ) + return output + + def prepare_decoder_input(self, input_dict, output): + decoder_input = { + "attn_emb": input_dict["attn_emb"], + "attn_emb_mask": input_dict["attn_emb_mask"] + } + t = input_dict["t"] + + ############### + # determine input word + ################ + if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling + word = input_dict["cap"][:, :t+1] + else: + start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long() + if t == 0: + word = start_word + else: + word = torch.cat((start_word, output["seq"][:, :t]), dim=-1) + # word: [N, T] + decoder_input["word"] = word + + return decoder_input + + def prepare_beamsearch_decoder_input(self, input_dict, output_i): + decoder_input = {} + t = input_dict["t"] + i = input_dict["sample_idx"] + beam_size = input_dict["beam_size"] + ############### + # prepare attn embeds + ################ + if t == 0: + attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size) + attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size) + output_i["attn_emb"] = attn_emb + output_i["attn_emb_mask"] = attn_emb_mask + decoder_input["attn_emb"] = output_i["attn_emb"] + decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"] + ############### + # determine input word + ################ + start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long() + if t == 0: + word = start_word + else: + word = torch.cat((start_word, output_i["seq"]), dim=-1) + decoder_input["word"] = word + + return decoder_input + + +class EventEncoder(nn.Module): + """ + Encode the Label information in AudioCaps and AudioSet + """ + def __init__(self, emb_dim, vocab_size=527): + super(EventEncoder, self).__init__() + self.label_embedding = nn.Parameter( + torch.randn((vocab_size, emb_dim)), requires_grad=True) + + def forward(self, word_idxs): + indices = word_idxs / word_idxs.sum(dim=1, keepdim=True) + embeddings = indices @ self.label_embedding + return embeddings + + +class EventCondTransformerModel(TransformerModel): + + def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs): + if not hasattr(self, "compatible_decoders"): + self.compatible_decoders = ( + captioning.models.decoder.EventTransformerDecoder, + ) + super().__init__(encoder, decoder, **kwargs) + self.label_encoder = EventEncoder(decoder.emb_dim, 527) + self.train_forward_keys += ["events"] + self.inference_forward_keys += ["events"] + + # def seq_forward(self, input_dict): + # cap = input_dict["cap"] + # cap_padding_mask = (cap == self.pad_idx).to(cap.device) + # cap_padding_mask = cap_padding_mask[:, :-1] + # 
output = self.decoder( + # { + # "word": cap[:, :-1], + # "attn_emb": input_dict["attn_emb"], + # "attn_emb_len": input_dict["attn_emb_len"], + # "cap_padding_mask": cap_padding_mask + # } + # ) + # return output + + def prepare_decoder_input(self, input_dict, output): + decoder_input = super().prepare_decoder_input(input_dict, output) + decoder_input["events"] = self.label_encoder(input_dict["events"]) + return decoder_input + + def prepare_beamsearch_decoder_input(self, input_dict, output_i): + decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i) + t = input_dict["t"] + i = input_dict["sample_idx"] + beam_size = input_dict["beam_size"] + if t == 0: + output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size) + decoder_input["events"] = output_i["events"] + return decoder_input + + +class KeywordCondTransformerModel(TransformerModel): + + def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs): + if not hasattr(self, "compatible_decoders"): + self.compatible_decoders = ( + captioning.models.decoder.KeywordProbTransformerDecoder, + ) + super().__init__(encoder, decoder, **kwargs) + self.train_forward_keys += ["keyword"] + self.inference_forward_keys += ["keyword"] + + def seq_forward(self, input_dict): + cap = input_dict["cap"] + cap_padding_mask = (cap == self.pad_idx).to(cap.device) + cap_padding_mask = cap_padding_mask[:, :-1] + keyword = input_dict["keyword"] + output = self.decoder( + { + "word": cap[:, :-1], + "attn_emb": input_dict["attn_emb"], + "attn_emb_len": input_dict["attn_emb_len"], + "keyword": keyword, + "cap_padding_mask": cap_padding_mask + } + ) + return output + + def prepare_decoder_input(self, input_dict, output): + decoder_input = super().prepare_decoder_input(input_dict, output) + decoder_input["keyword"] = input_dict["keyword"] + return decoder_input + + def prepare_beamsearch_decoder_input(self, input_dict, output_i): + decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i) + t = input_dict["t"] + i = input_dict["sample_idx"] + beam_size = input_dict["beam_size"] + if t == 0: + output_i["keyword"] = repeat_tensor(input_dict["keyword"][i], + beam_size) + decoder_input["keyword"] = output_i["keyword"] + return decoder_input + diff --git a/audio_to_text/captioning/models/utils.py b/audio_to_text/captioning/models/utils.py new file mode 100644 index 0000000..3623cf4 --- /dev/null +++ b/audio_to_text/captioning/models/utils.py @@ -0,0 +1,132 @@ +import math + +import numpy as np +import torch +import torch.nn as nn + +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence + + +def sort_pack_padded_sequence(input, lengths): + sorted_lengths, indices = torch.sort(lengths, descending=True) + tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True) + inv_ix = indices.clone() + inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) + return tmp, inv_ix + +def pad_unsort_packed_sequence(input, inv_ix): + tmp, _ = pad_packed_sequence(input, batch_first=True) + tmp = tmp[inv_ix] + return tmp + +def pack_wrapper(module, attn_feats, attn_feat_lens): + packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens) + if isinstance(module, torch.nn.RNNBase): + return pad_unsort_packed_sequence(module(packed)[0], inv_ix) + else: + return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) + +def generate_length_mask(lens, max_length=None): + lens = torch.as_tensor(lens) + N = 
lens.size(0) + if max_length is None: + max_length = max(lens) + idxs = torch.arange(max_length).repeat(N).view(N, max_length) + idxs = idxs.to(lens.device) + mask = (idxs < lens.view(-1, 1)) + return mask + +def mean_with_lens(features, lens): + """ + features: [N, T, ...] (assume the second dimension represents length) + lens: [N,] + """ + lens = torch.as_tensor(lens) + if max(lens) != features.size(1): + max_length = features.size(1) + mask = generate_length_mask(lens, max_length) + else: + mask = generate_length_mask(lens) + mask = mask.to(features.device) # [N, T] + + while mask.ndim < features.ndim: + mask = mask.unsqueeze(-1) + feature_mean = features * mask + feature_mean = feature_mean.sum(1) + while lens.ndim < feature_mean.ndim: + lens = lens.unsqueeze(1) + feature_mean = feature_mean / lens.to(features.device) + # feature_mean = features * mask.unsqueeze(-1) + # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device) + return feature_mean + +def max_with_lens(features, lens): + """ + features: [N, T, ...] (assume the second dimension represents length) + lens: [N,] + """ + lens = torch.as_tensor(lens) + mask = generate_length_mask(lens).to(features.device) # [N, T] + + feature_max = features.clone() + feature_max[~mask] = float("-inf") + feature_max, _ = feature_max.max(1) + return feature_max + +def repeat_tensor(x, n): + return x.unsqueeze(0).repeat(n, *([1] * len(x.shape))) + +def init(m, method="kaiming"): + if isinstance(m, (nn.Conv2d, nn.Conv1d)): + if method == "kaiming": + nn.init.kaiming_uniform_(m.weight) + elif method == "xavier": + nn.init.xavier_uniform_(m.weight) + else: + raise Exception(f"initialization method {method} not supported") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + if method == "kaiming": + nn.init.kaiming_uniform_(m.weight) + elif method == "xavier": + nn.init.xavier_uniform_(m.weight) + else: + raise Exception(f"initialization method {method} not supported") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Embedding): + if method == "kaiming": + nn.init.kaiming_uniform_(m.weight) + elif method == "xavier": + nn.init.xavier_uniform_(m.weight) + else: + raise Exception(f"initialization method {method} not supported") + + + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=100): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * \ + (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + # self.register_buffer("pe", pe) + self.register_parameter("pe", nn.Parameter(pe, requires_grad=False)) + + def forward(self, x): + # x: [T, N, E] + x = x + self.pe[:x.size(0), :] + return self.dropout(x) diff --git a/audio_to_text/captioning/utils/README.md b/audio_to_text/captioning/utils/README.md new file mode 100644 index 0000000..c6fd17d --- /dev/null +++ b/audio_to_text/captioning/utils/README.md @@ -0,0 +1,19 @@ +# Utils + +Scripts in this directory are used as utility functions. 
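+
+## Build Vocabulary
+
+`utils/build_vocab.py` tokenizes the caption json and builds the word-index vocabulary used for training. A typical invocation looks like the following (the data paths are placeholders, replace them with your own files):
+
+```bash
+python utils/build_vocab.py data/train/text.json data/train/vocab.pkl --threshold 1
+```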
+ +## BERT Pretrained Embeddings + +You can load pretrained word embeddings in Google [BERT](https://github.com/google-research/bert#pre-trained-models) instead of training word embeddings from scratch. The scripts in `utils/bert` need a BERT server in the background. We use BERT server from [bert-as-service](https://github.com/hanxiao/bert-as-service). + +To use bert-as-service, you need to first install the repository. It is recommended that you create a new environment with Tensorflow 1.3 to run BERT server since it is incompatible with Tensorflow 2.x. + +After successful installation of [bert-as-service](https://github.com/hanxiao/bert-as-service), downloading and running the BERT server needs to execute: + +```bash +bash scripts/prepare_bert_server.sh zh +``` + +By default, server based on BERT base Chinese model is running in the background. You can change to other models by changing corresponding model name and path in `scripts/prepare_bert_server.sh`. + +To extract BERT word embeddings, you need to execute `utils/bert/create_word_embedding.py`. diff --git a/audio_to_text/captioning/utils/__init__.py b/audio_to_text/captioning/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/audio_to_text/captioning/utils/bert/create_sent_embedding.py b/audio_to_text/captioning/utils/bert/create_sent_embedding.py new file mode 100644 index 0000000..b517a32 --- /dev/null +++ b/audio_to_text/captioning/utils/bert/create_sent_embedding.py @@ -0,0 +1,89 @@ +import pickle +import fire +import numpy as np +import pandas as pd +from tqdm import tqdm + + +class EmbeddingExtractor(object): + + def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False): + from sentence_transformers import SentenceTransformer + lang2model = { + "zh": "distiluse-base-multilingual-cased", + "en": "bert-base-nli-mean-tokens" + } + lang = "zh" if zh else "en" + model = SentenceTransformer(lang2model[lang]) + + self.extract(caption_file, model, output, dev) + + def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"): + from bert_serving.client import BertClient + client = BertClient(ip) + + self.extract(caption_file, client, output, dev) + + def extract(self, caption_file: str, model, output, dev: bool): + caption_df = pd.read_json(caption_file, dtype={"key": str}) + embeddings = {} + + if dev: + with tqdm(total=caption_df.shape[0], ascii=True) as pbar: + for idx, row in caption_df.iterrows(): + caption = row["caption"] + key = row["key"] + cap_idx = row["caption_index"] + embedding = model.encode([caption]) + embedding = np.array(embedding).reshape(-1) + embeddings[f"{key}_{cap_idx}"] = embedding + pbar.update() + + else: + dump = {} + + with tqdm(total=caption_df.shape[0], ascii=True) as pbar: + for idx, row in caption_df.iterrows(): + key = row["key"] + caption = row["caption"] + value = np.array(model.encode([caption])).reshape(-1) + + if key not in embeddings.keys(): + embeddings[key] = [value] + else: + embeddings[key].append(value) + + pbar.update() + + for key in embeddings: + dump[key] = np.stack(embeddings[key]) + + embeddings = dump + + with open(output, "wb") as f: + pickle.dump(embeddings, f) + + def extract_sbert(self, + input_json: str, + output: str): + from sentence_transformers import SentenceTransformer + import json + import torch + from h5py import File + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = SentenceTransformer("paraphrase-MiniLM-L6-v2") + model = model.to(device) + 
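+        # encode every caption with Sentence-BERT and store the embeddings in an HDF5 file
+        # keyed by "{audio_id}_{cap_id}"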
model.eval() + + data = json.load(open(input_json))["audios"] + with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store: + for sample in data: + audio_id = sample["audio_id"] + for cap in sample["captions"]: + cap_id = cap["cap_id"] + store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"]) + pbar.update() + + +if __name__ == "__main__": + fire.Fire(EmbeddingExtractor) diff --git a/audio_to_text/captioning/utils/bert/create_word_embedding.py b/audio_to_text/captioning/utils/bert/create_word_embedding.py new file mode 100644 index 0000000..43c980e --- /dev/null +++ b/audio_to_text/captioning/utils/bert/create_word_embedding.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +import sys +import os + +from bert_serving.client import BertClient +import numpy as np +from tqdm import tqdm +import fire +import torch + +sys.path.append(os.getcwd()) +from utils.build_vocab import Vocabulary + +def main(vocab_file: str, output: str, server_hostname: str): + client = BertClient(ip=server_hostname) + vocabulary = torch.load(vocab_file) + vocab_size = len(vocabulary) + + fake_embedding = client.encode(["test"]).reshape(-1) + embed_size = fake_embedding.shape[0] + + print("Encoding words into embeddings with size: ", embed_size) + + embeddings = np.empty((vocab_size, embed_size)) + for i in tqdm(range(len(embeddings)), ascii=True): + embeddings[i] = client.encode([vocabulary.idx2word[i]]) + np.save(output, embeddings) + + +if __name__ == '__main__': + fire.Fire(main) + + diff --git a/audio_to_text/captioning/utils/build_vocab.py b/audio_to_text/captioning/utils/build_vocab.py new file mode 100644 index 0000000..e9fab23 --- /dev/null +++ b/audio_to_text/captioning/utils/build_vocab.py @@ -0,0 +1,153 @@ +import json +from tqdm import tqdm +import logging +import pickle +from collections import Counter +import re +import fire + + +class Vocabulary(object): + """Simple vocabulary wrapper.""" + def __init__(self): + self.word2idx = {} + self.idx2word = {} + self.idx = 0 + + def add_word(self, word): + if not word in self.word2idx: + self.word2idx[word] = self.idx + self.idx2word[self.idx] = word + self.idx += 1 + + def __call__(self, word): + if not word in self.word2idx: + return self.word2idx[""] + return self.word2idx[word] + + def __getitem__(self, word_id): + return self.idx2word[word_id] + + def __len__(self): + return len(self.word2idx) + + +def build_vocab(input_json: str, + threshold: int, + keep_punctuation: bool, + host_address: str, + character_level: bool = False, + zh: bool = True ): + """Build vocabulary from csv file with a given threshold to drop all counts < threshold + + Args: + input_json(string): Preprossessed json file. Structure like this: + { + 'audios': [ + { + 'audio_id': 'xxx', + 'captions': [ + { + 'caption': 'xxx', + 'cap_id': 'xxx' + } + ] + }, + ... + ] + } + threshold (int): Threshold to drop all words with counts < threshold + keep_punctuation (bool): Includes or excludes punctuation. 
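+        host_address (string): Address of the CoreNLP tokenization server, only used when zh=True and the captions are not pre-tokenized
+        character_level (bool): Tokenize Chinese captions into characters instead of words
+        zh (bool): Whether the captions are in Chinese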
+ + Returns: + vocab (Vocab): Object with the processed vocabulary +""" + data = json.load(open(input_json, "r"))["audios"] + counter = Counter() + pretokenized = "tokens" in data[0]["captions"][0] + + if zh: + from nltk.parse.corenlp import CoreNLPParser + from zhon.hanzi import punctuation + if not pretokenized: + parser = CoreNLPParser(host_address) + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + for cap_idx in range(len(data[audio_idx]["captions"])): + if pretokenized: + tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split() + else: + caption = data[audio_idx]["captions"][cap_idx]["caption"] + # Remove all punctuations + if not keep_punctuation: + caption = re.sub("[{}]".format(punctuation), "", caption) + if character_level: + tokens = list(caption) + else: + tokens = list(parser.tokenize(caption)) + data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens) + counter.update(tokens) + else: + if pretokenized: + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + for cap_idx in range(len(data[audio_idx]["captions"])): + tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split() + counter.update(tokens) + else: + from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer + captions = {} + for audio_idx in range(len(data)): + audio_id = data[audio_idx]["audio_id"] + captions[audio_id] = [] + for cap_idx in range(len(data[audio_idx]["captions"])): + caption = data[audio_idx]["captions"][cap_idx]["caption"] + captions[audio_id].append({ + "audio_id": audio_id, + "id": cap_idx, + "caption": caption + }) + tokenizer = PTBTokenizer() + captions = tokenizer.tokenize(captions) + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + audio_id = data[audio_idx]["audio_id"] + for cap_idx in range(len(data[audio_idx]["captions"])): + tokens = captions[audio_id][cap_idx] + data[audio_idx]["captions"][cap_idx]["tokens"] = tokens + counter.update(tokens.split(" ")) + + if not pretokenized: + json.dump({ "audios": data }, open(input_json, "w"), indent=4, ensure_ascii=not zh) + words = [word for word, cnt in counter.items() if cnt >= threshold] + + # Create a vocab wrapper and add some special tokens. + vocab = Vocabulary() + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + + # Add the words to the vocabulary. 
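+    # Indices 0-3 are reserved for the special tokens added above; regular words follow.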
+ for word in words: + vocab.add_word(word) + return vocab + + +def process(input_json: str, + output_file: str, + threshold: int = 1, + keep_punctuation: bool = False, + character_level: bool = False, + host_address: str = "http://localhost:9000", + zh: bool = False): + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + logging.basicConfig(level=logging.INFO, format=logfmt) + logging.info("Build Vocab") + vocabulary = build_vocab( + input_json=input_json, threshold=threshold, keep_punctuation=keep_punctuation, + host_address=host_address, character_level=character_level, zh=zh) + pickle.dump(vocabulary, open(output_file, "wb")) + logging.info("Total vocabulary size: {}".format(len(vocabulary))) + logging.info("Saved vocab to '{}'".format(output_file)) + + +if __name__ == '__main__': + fire.Fire(process) diff --git a/audio_to_text/captioning/utils/build_vocab_ltp.py b/audio_to_text/captioning/utils/build_vocab_ltp.py new file mode 100644 index 0000000..aae0c71 --- /dev/null +++ b/audio_to_text/captioning/utils/build_vocab_ltp.py @@ -0,0 +1,150 @@ +import json +from tqdm import tqdm +import logging +import pickle +from collections import Counter +import re +import fire + +class Vocabulary(object): + """Simple vocabulary wrapper.""" + def __init__(self): + self.word2idx = {} + self.idx2word = {} + self.idx = 0 + + def add_word(self, word): + if not word in self.word2idx: + self.word2idx[word] = self.idx + self.idx2word[self.idx] = word + self.idx += 1 + + def __call__(self, word): + if not word in self.word2idx: + return self.word2idx[""] + return self.word2idx[word] + + def __len__(self): + return len(self.word2idx) + +def build_vocab(input_json: str, + output_json: str, + threshold: int, + keep_punctuation: bool, + character_level: bool = False, + zh: bool = True ): + """Build vocabulary from csv file with a given threshold to drop all counts < threshold + + Args: + input_json(string): Preprossessed json file. Structure like this: + { + 'audios': [ + { + 'audio_id': 'xxx', + 'captions': [ + { + 'caption': 'xxx', + 'cap_id': 'xxx' + } + ] + }, + ... + ] + } + threshold (int): Threshold to drop all words with counts < threshold + keep_punctuation (bool): Includes or excludes punctuation. 
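+        output_json (string): If given, the json with the added "tokens" field is written here instead of overwriting input_json
+        character_level (bool): Tokenize Chinese captions into characters instead of words
+        zh (bool): Whether the captions are in Chinese (tokenized with LTP)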
+ + Returns: + vocab (Vocab): Object with the processed vocabulary +""" + data = json.load(open(input_json, "r"))["audios"] + counter = Counter() + pretokenized = "tokens" in data[0]["captions"][0] + + if zh: + from ltp import LTP + from zhon.hanzi import punctuation + if not pretokenized: + parser = LTP("base") + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + for cap_idx in range(len(data[audio_idx]["captions"])): + if pretokenized: + tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split() + else: + caption = data[audio_idx]["captions"][cap_idx]["caption"] + if character_level: + tokens = list(caption) + else: + tokens, _ = parser.seg([caption]) + tokens = tokens[0] + # Remove all punctuations + if not keep_punctuation: + tokens = [token for token in tokens if token not in punctuation] + data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens) + counter.update(tokens) + else: + if pretokenized: + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + for cap_idx in range(len(data[audio_idx]["captions"])): + tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split() + counter.update(tokens) + else: + from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer + captions = {} + for audio_idx in range(len(data)): + audio_id = data[audio_idx]["audio_id"] + captions[audio_id] = [] + for cap_idx in range(len(data[audio_idx]["captions"])): + caption = data[audio_idx]["captions"][cap_idx]["caption"] + captions[audio_id].append({ + "audio_id": audio_id, + "id": cap_idx, + "caption": caption + }) + tokenizer = PTBTokenizer() + captions = tokenizer.tokenize(captions) + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + audio_id = data[audio_idx]["audio_id"] + for cap_idx in range(len(data[audio_idx]["captions"])): + tokens = captions[audio_id][cap_idx] + data[audio_idx]["captions"][cap_idx]["tokens"] = tokens + counter.update(tokens.split(" ")) + + if not pretokenized: + if output_json is None: + output_json = input_json + json.dump({ "audios": data }, open(output_json, "w"), indent=4, ensure_ascii=not zh) + words = [word for word, cnt in counter.items() if cnt >= threshold] + + # Create a vocab wrapper and add some special tokens. + vocab = Vocabulary() + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + + # Add the words to the vocabulary. 
+ for word in words: + vocab.add_word(word) + return vocab + +def process(input_json: str, + output_file: str, + output_json: str = None, + threshold: int = 1, + keep_punctuation: bool = False, + character_level: bool = False, + zh: bool = True): + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + logging.basicConfig(level=logging.INFO, format=logfmt) + logging.info("Build Vocab") + vocabulary = build_vocab( + input_json=input_json, output_json=output_json, threshold=threshold, + keep_punctuation=keep_punctuation, character_level=character_level, zh=zh) + pickle.dump(vocabulary, open(output_file, "wb")) + logging.info("Total vocabulary size: {}".format(len(vocabulary))) + logging.info("Saved vocab to '{}'".format(output_file)) + + +if __name__ == '__main__': + fire.Fire(process) diff --git a/audio_to_text/captioning/utils/build_vocab_spacy.py b/audio_to_text/captioning/utils/build_vocab_spacy.py new file mode 100644 index 0000000..84da679 --- /dev/null +++ b/audio_to_text/captioning/utils/build_vocab_spacy.py @@ -0,0 +1,152 @@ +import json +from tqdm import tqdm +import logging +import pickle +from collections import Counter +import re +import fire + +class Vocabulary(object): + """Simple vocabulary wrapper.""" + def __init__(self): + self.word2idx = {} + self.idx2word = {} + self.idx = 0 + + def add_word(self, word): + if not word in self.word2idx: + self.word2idx[word] = self.idx + self.idx2word[self.idx] = word + self.idx += 1 + + def __call__(self, word): + if not word in self.word2idx: + return self.word2idx[""] + return self.word2idx[word] + + def __len__(self): + return len(self.word2idx) + + +def build_vocab(input_json: str, + output_json: str, + threshold: int, + keep_punctuation: bool, + host_address: str, + character_level: bool = False, + retokenize: bool = True, + zh: bool = True ): + """Build vocabulary from csv file with a given threshold to drop all counts < threshold + + Args: + input_json(string): Preprossessed json file. Structure like this: + { + 'audios': [ + { + 'audio_id': 'xxx', + 'captions': [ + { + 'caption': 'xxx', + 'cap_id': 'xxx' + } + ] + }, + ... + ] + } + threshold (int): Threshold to drop all words with counts < threshold + keep_punctuation (bool): Includes or excludes punctuation. 
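+        output_json (string): If given, the json with the added "tokens" field is written here instead of overwriting input_json
+        host_address (string): Address of the CoreNLP tokenization server, only used when zh=True
+        character_level (bool): Tokenize Chinese captions into characters instead of words
+        retokenize (bool): Re-tokenize captions even if a "tokens" field is already present
+        zh (bool): Whether the captions are in Chinese; English captions are tokenized with spaCy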
+ + Returns: + vocab (Vocab): Object with the processed vocabulary +""" + data = json.load(open(input_json, "r"))["audios"] + counter = Counter() + if retokenize: + pretokenized = False + else: + pretokenized = "tokens" in data[0]["captions"][0] + + if zh: + from nltk.parse.corenlp import CoreNLPParser + from zhon.hanzi import punctuation + if not pretokenized: + parser = CoreNLPParser(host_address) + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + for cap_idx in range(len(data[audio_idx]["captions"])): + if pretokenized: + tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split() + else: + caption = data[audio_idx]["captions"][cap_idx]["caption"] + # Remove all punctuations + if not keep_punctuation: + caption = re.sub("[{}]".format(punctuation), "", caption) + if character_level: + tokens = list(caption) + else: + tokens = list(parser.tokenize(caption)) + data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens) + counter.update(tokens) + else: + if pretokenized: + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + for cap_idx in range(len(data[audio_idx]["captions"])): + tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split() + counter.update(tokens) + else: + import spacy + tokenizer = spacy.load("en_core_web_sm", disable=["parser", "ner"]) + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + captions = data[audio_idx]["captions"] + for cap_idx in range(len(captions)): + caption = captions[cap_idx]["caption"] + doc = tokenizer(caption) + tokens = " ".join([str(token).lower() for token in doc]) + data[audio_idx]["captions"][cap_idx]["tokens"] = tokens + counter.update(tokens.split(" ")) + + if not pretokenized: + if output_json is None: + json.dump({ "audios": data }, open(input_json, "w"), + indent=4, ensure_ascii=not zh) + else: + json.dump({ "audios": data }, open(output_json, "w"), + indent=4, ensure_ascii=not zh) + + words = [word for word, cnt in counter.items() if cnt >= threshold] + + # Create a vocab wrapper and add some special tokens. + vocab = Vocabulary() + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + + # Add the words to the vocabulary. 
+ for word in words: + vocab.add_word(word) + return vocab + +def process(input_json: str, + output_file: str, + output_json: str = None, + threshold: int = 1, + keep_punctuation: bool = False, + character_level: bool = False, + retokenize: bool = False, + host_address: str = "http://localhost:9000", + zh: bool = True): + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + logging.basicConfig(level=logging.INFO, format=logfmt) + logging.info("Build Vocab") + vocabulary = build_vocab( + input_json=input_json, output_json=output_json, threshold=threshold, + keep_punctuation=keep_punctuation, host_address=host_address, + character_level=character_level, retokenize=retokenize, zh=zh) + pickle.dump(vocabulary, open(output_file, "wb")) + logging.info("Total vocabulary size: {}".format(len(vocabulary))) + logging.info("Saved vocab to '{}'".format(output_file)) + + +if __name__ == '__main__': + fire.Fire(process) diff --git a/audio_to_text/captioning/utils/eval_round_robin.py b/audio_to_text/captioning/utils/eval_round_robin.py new file mode 100644 index 0000000..28603a5 --- /dev/null +++ b/audio_to_text/captioning/utils/eval_round_robin.py @@ -0,0 +1,182 @@ +import copy +import json + +import numpy as np +import fire + + +def evaluate_annotation(key2refs, scorer): + if scorer.method() == "Bleu": + scores = np.array([ 0.0 for n in range(4) ]) + else: + scores = 0 + num_cap_per_audio = len(next(iter(key2refs.values()))) + + for i in range(num_cap_per_audio): + if i > 0: + for key in key2refs: + key2refs[key].insert(0, res[key][0]) + res = { key: [refs.pop(),] for key, refs in key2refs.items() } + score, _ = scorer.compute_score(key2refs, res) + + if scorer.method() == "Bleu": + scores += np.array(score) + else: + scores += score + + score = scores / num_cap_per_audio + return score + +def evaluate_prediction(key2pred, key2refs, scorer): + if scorer.method() == "Bleu": + scores = np.array([ 0.0 for n in range(4) ]) + else: + scores = 0 + num_cap_per_audio = len(next(iter(key2refs.values()))) + + for i in range(num_cap_per_audio): + key2refs_i = {} + for key, refs in key2refs.items(): + key2refs_i[key] = refs[:i] + refs[i+1:] + score, _ = scorer.compute_score(key2refs_i, key2pred) + + if scorer.method() == "Bleu": + scores += np.array(score) + else: + scores += score + + score = scores / num_cap_per_audio + return score + + +class Evaluator(object): + + def eval_annotation(self, annotation, output): + captions = json.load(open(annotation, "r"))["audios"] + + key2refs = {} + for audio_idx in range(len(captions)): + audio_id = captions[audio_idx]["audio_id"] + key2refs[audio_id] = [] + for caption in captions[audio_idx]["captions"]: + key2refs[audio_id].append(caption["caption"]) + + from fense.fense import Fense + scores = {} + scorer = Fense() + scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer) + + refs4eval = {} + for key, refs in key2refs.items(): + refs4eval[key] = [] + for idx, ref in enumerate(refs): + refs4eval[key].append({ + "audio_id": key, + "id": idx, + "caption": ref + }) + + from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer + + tokenizer = PTBTokenizer() + key2refs = tokenizer.tokenize(refs4eval) + + + from pycocoevalcap.bleu.bleu import Bleu + from pycocoevalcap.cider.cider import Cider + from pycocoevalcap.rouge.rouge import Rouge + from pycocoevalcap.meteor.meteor import Meteor + from pycocoevalcap.spice.spice import Spice + + + scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()] + for scorer in scorers: + 
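+            # every scorer is evaluated in a leave-one-out round robin over the reference
+            # captions (see evaluate_annotation above)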
scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer) + + spider = 0 + with open(output, "w") as f: + for name, score in scores.items(): + if name == "Bleu": + for n in range(4): + f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n])) + else: + f.write("{}: {:6.3f}\n".format(name, score)) + if name in ["CIDEr", "SPICE"]: + spider += score + f.write("SPIDEr: {:6.3f}\n".format(spider / 2)) + + def eval_prediction(self, prediction, annotation, output): + ref_captions = json.load(open(annotation, "r"))["audios"] + + key2refs = {} + for audio_idx in range(len(ref_captions)): + audio_id = ref_captions[audio_idx]["audio_id"] + key2refs[audio_id] = [] + for caption in ref_captions[audio_idx]["captions"]: + key2refs[audio_id].append(caption["caption"]) + + pred_captions = json.load(open(prediction, "r"))["predictions"] + + key2pred = {} + for audio_idx in range(len(pred_captions)): + item = pred_captions[audio_idx] + audio_id = item["filename"] + key2pred[audio_id] = [item["tokens"]] + + from fense.fense import Fense + scores = {} + scorer = Fense() + scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer) + + refs4eval = {} + for key, refs in key2refs.items(): + refs4eval[key] = [] + for idx, ref in enumerate(refs): + refs4eval[key].append({ + "audio_id": key, + "id": idx, + "caption": ref + }) + + preds4eval = {} + for key, preds in key2pred.items(): + preds4eval[key] = [] + for idx, pred in enumerate(preds): + preds4eval[key].append({ + "audio_id": key, + "id": idx, + "caption": pred + }) + + from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer + + tokenizer = PTBTokenizer() + key2refs = tokenizer.tokenize(refs4eval) + key2pred = tokenizer.tokenize(preds4eval) + + + from pycocoevalcap.bleu.bleu import Bleu + from pycocoevalcap.cider.cider import Cider + from pycocoevalcap.rouge.rouge import Rouge + from pycocoevalcap.meteor.meteor import Meteor + from pycocoevalcap.spice.spice import Spice + + scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()] + for scorer in scorers: + scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer) + + spider = 0 + with open(output, "w") as f: + for name, score in scores.items(): + if name == "Bleu": + for n in range(4): + f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n])) + else: + f.write("{}: {:6.3f}\n".format(name, score)) + if name in ["CIDEr", "SPICE"]: + spider += score + f.write("SPIDEr: {:6.3f}\n".format(spider / 2)) + + +if __name__ == "__main__": + fire.Fire(Evaluator) diff --git a/audio_to_text/captioning/utils/fasttext/create_word_embedding.py b/audio_to_text/captioning/utils/fasttext/create_word_embedding.py new file mode 100644 index 0000000..09da13a --- /dev/null +++ b/audio_to_text/captioning/utils/fasttext/create_word_embedding.py @@ -0,0 +1,50 @@ +# coding=utf-8 +#!/usr/bin/env python3 + +import numpy as np +import pandas as pd +import torch +from gensim.models import FastText +from tqdm import tqdm +import fire + +import sys +import os +sys.path.append(os.getcwd()) +from utils.build_vocab import Vocabulary + +def create_embedding(caption_file: str, + vocab_file: str, + embed_size: int, + output: str, + **fasttext_kwargs): + caption_df = pd.read_json(caption_file) + caption_df["tokens"] = caption_df["tokens"].apply(lambda x: [""] + [token for token in x] + [""]) + + sentences = list(caption_df["tokens"].values) + vocabulary = torch.load(vocab_file, map_location="cpu") + + epochs = fasttext_kwargs.get("epochs", 10) + model = FastText(size=embed_size, min_count=1, 
**fasttext_kwargs) + model.build_vocab(sentences=sentences) + model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs) + + word_embeddings = np.zeros((len(vocabulary), embed_size)) + + with tqdm(total=len(vocabulary), ascii=True) as pbar: + for word, idx in vocabulary.word2idx.items(): + if word == "" or word == "": + continue + word_embeddings[idx] = model.wv[word] + pbar.update() + + np.save(output, word_embeddings) + + print("Finish writing fasttext embeddings to " + output) + + +if __name__ == "__main__": + fire.Fire(create_embedding) + + + diff --git a/audio_to_text/captioning/utils/lr_scheduler.py b/audio_to_text/captioning/utils/lr_scheduler.py new file mode 100644 index 0000000..b46e3f0 --- /dev/null +++ b/audio_to_text/captioning/utils/lr_scheduler.py @@ -0,0 +1,128 @@ +import math +import torch + + +class ExponentialDecayScheduler(torch.optim.lr_scheduler._LRScheduler): + + def __init__(self, optimizer, total_iters, final_lrs, + warmup_iters=3000, last_epoch=-1, verbose=False): + self.total_iters = total_iters + self.final_lrs = final_lrs + if not isinstance(self.final_lrs, list) and not isinstance( + self.final_lrs, tuple): + self.final_lrs = [self.final_lrs] * len(optimizer.param_groups) + self.warmup_iters = warmup_iters + self.bases = [0.0,] * len(optimizer.param_groups) + super().__init__(optimizer, last_epoch, verbose) + for i, (base_lr, final_lr) in enumerate(zip(self.base_lrs, self.final_lrs)): + base = (final_lr / base_lr) ** (1 / ( + self.total_iters - self.warmup_iters)) + self.bases[i] = base + + def _get_closed_form_lr(self): + warmup_coeff = 1.0 + current_iter = self._step_count + if current_iter < self.warmup_iters: + warmup_coeff = current_iter / self.warmup_iters + current_lrs = [] + # if not self.linear_warmup: + # for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs, self.bases): + # # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr)) + # current_lr = warmup_coeff * base_lr * (base ** (current_iter - self.warmup_iters)) + # current_lrs.append(current_lr) + # else: + for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs, + self.bases): + if current_iter <= self.warmup_iters: + current_lr = warmup_coeff * base_lr + else: + # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr)) + current_lr = base_lr * (base ** (current_iter - self.warmup_iters)) + current_lrs.append(current_lr) + return current_lrs + + def get_lr(self): + return self._get_closed_form_lr() + + +class NoamScheduler(torch.optim.lr_scheduler._LRScheduler): + + def __init__(self, optimizer, model_size=512, factor=1, warmup_iters=3000, + last_epoch=-1, verbose=False): + self.model_size = model_size + self.warmup_iters = warmup_iters + # self.factors = [group["lr"] / (self.model_size ** (-0.5) * self.warmup_iters ** (-0.5)) for group in optimizer.param_groups] + self.factor = factor + super().__init__(optimizer, last_epoch, verbose) + + def _get_closed_form_lr(self): + current_iter = self._step_count + current_lrs = [] + for _ in self.base_lrs: + current_lr = self.factor * \ + (self.model_size ** (-0.5) * min(current_iter ** (-0.5), + current_iter * self.warmup_iters ** (-1.5))) + current_lrs.append(current_lr) + return current_lrs + + def get_lr(self): + return self._get_closed_form_lr() + + +class CosineWithWarmup(torch.optim.lr_scheduler._LRScheduler): + + def __init__(self, optimizer, 
total_iters, warmup_iters, + num_cycles=0.5, last_epoch=-1, verbose=False): + self.total_iters = total_iters + self.warmup_iters = warmup_iters + self.num_cycles = num_cycles + super().__init__(optimizer, last_epoch, verbose) + + def lr_lambda(self, iteration): + if iteration < self.warmup_iters: + return float(iteration) / float(max(1, self.warmup_iters)) + progress = float(iteration - self.warmup_iters) / float(max(1, + self.total_iters - self.warmup_iters)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float( + self.num_cycles) * 2.0 * progress))) + + def _get_closed_form_lr(self): + current_iter = self._step_count + current_lrs = [] + for base_lr in self.base_lrs: + current_lr = base_lr * self.lr_lambda(current_iter) + current_lrs.append(current_lr) + return current_lrs + + def get_lr(self): + return self._get_closed_form_lr() + + +if __name__ == "__main__": + model = torch.nn.Linear(10, 5) + optimizer = torch.optim.Adam(model.parameters(), 5e-4) + epochs = 25 + iters = 600 + scheduler = CosineWithWarmup(optimizer, 600 * 25, 600 * 5,) + # scheduler = ExponentialDecayScheduler(optimizer, 600 * 25, 5e-7, 600 * 5) + criterion = torch.nn.MSELoss() + lrs = [] + for epoch in range(1, epochs + 1): + for iteration in range(1, iters + 1): + optimizer.zero_grad() + x = torch.randn(4, 10) + y = torch.randn(4, 5) + loss = criterion(model(x), y) + loss.backward() + optimizer.step() + scheduler.step() + # print(f"lr: {scheduler.get_last_lr()}") + # lrs.append(scheduler.get_last_lr()) + lrs.append(optimizer.param_groups[0]["lr"]) + import matplotlib.pyplot as plt + plt.plot(list(range(1, len(lrs) + 1)), lrs, '-o', markersize=1) + # plt.legend(loc="best") + plt.xlabel("Iteration") + plt.ylabel("LR") + + plt.savefig("lr_curve.png", dpi=100) diff --git a/audio_to_text/captioning/utils/model_eval_diff.py b/audio_to_text/captioning/utils/model_eval_diff.py new file mode 100644 index 0000000..2c29ef8 --- /dev/null +++ b/audio_to_text/captioning/utils/model_eval_diff.py @@ -0,0 +1,110 @@ +import os +import sys +import copy +import pickle + +import numpy as np +import pandas as pd +import fire + +sys.path.append(os.getcwd()) + + +def coco_score(refs, pred, scorer): + if scorer.method() == "Bleu": + scores = np.array([ 0.0 for n in range(4) ]) + else: + scores = 0 + num_cap_per_audio = len(refs[list(refs.keys())[0]]) + + for i in range(num_cap_per_audio): + if i > 0: + for key in refs: + refs[key].insert(0, res[key][0]) + res = {key: [refs[key].pop(),] for key in refs} + score, _ = scorer.compute_score(refs, pred) + + if scorer.method() == "Bleu": + scores += np.array(score) + else: + scores += score + + score = scores / num_cap_per_audio + + for key in refs: + refs[key].insert(0, res[key][0]) + score_allref, _ = scorer.compute_score(refs, pred) + diff = score_allref - score + return diff + +def embedding_score(refs, pred, scorer): + + num_cap_per_audio = len(refs[list(refs.keys())[0]]) + scores = 0 + + for i in range(num_cap_per_audio): + res = {key: [refs[key][i],] for key in refs.keys() if len(refs[key]) == num_cap_per_audio} + refs_i = {key: np.concatenate([refs[key][:i], refs[key][i+1:]]) for key in refs.keys() if len(refs[key]) == num_cap_per_audio} + score, _ = scorer.compute_score(refs_i, pred) + + scores += score + + score = scores / num_cap_per_audio + + score_allref, _ = scorer.compute_score(refs, pred) + diff = score_allref - score + return diff + +def main(output_file, eval_caption_file, eval_embedding_file, output, zh=False): + output_df = pd.read_json(output_file) + output_df["key"] = 
output_df["filename"].apply(lambda x: os.path.splitext(os.path.basename(x))[0]) + pred = output_df.groupby("key")["tokens"].apply(list).to_dict() + + label_df = pd.read_json(eval_caption_file) + if zh: + refs = label_df.groupby("key")["tokens"].apply(list).to_dict() + else: + refs = label_df.groupby("key")["caption"].apply(list).to_dict() + + from pycocoevalcap.bleu.bleu import Bleu + from pycocoevalcap.cider.cider import Cider + from pycocoevalcap.rouge.rouge import Rouge + + scorer = Bleu(zh=zh) + bleu_scores = coco_score(copy.deepcopy(refs), pred, scorer) + scorer = Cider(zh=zh) + cider_score = coco_score(copy.deepcopy(refs), pred, scorer) + scorer = Rouge(zh=zh) + rouge_score = coco_score(copy.deepcopy(refs), pred, scorer) + + if not zh: + from pycocoevalcap.meteor.meteor import Meteor + scorer = Meteor() + meteor_score = coco_score(copy.deepcopy(refs), pred, scorer) + + from pycocoevalcap.spice.spice import Spice + scorer = Spice() + spice_score = coco_score(copy.deepcopy(refs), pred, scorer) + + # from audiocaptioneval.sentbert.sentencebert import SentenceBert + # scorer = SentenceBert(zh=zh) + # with open(eval_embedding_file, "rb") as f: + # ref_embeddings = pickle.load(f) + + # sent_bert = embedding_score(ref_embeddings, pred, scorer) + + with open(output, "w") as f: + f.write("Diff:\n") + for n in range(4): + f.write("BLEU-{}: {:6.3f}\n".format(n+1, bleu_scores[n])) + f.write("CIDEr: {:6.3f}\n".format(cider_score)) + f.write("ROUGE: {:6.3f}\n".format(rouge_score)) + if not zh: + f.write("Meteor: {:6.3f}\n".format(meteor_score)) + f.write("SPICE: {:6.3f}\n".format(spice_score)) + # f.write("SentenceBert: {:6.3f}\n".format(sent_bert)) + + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/audio_to_text/captioning/utils/predict_nn.py b/audio_to_text/captioning/utils/predict_nn.py new file mode 100644 index 0000000..699c3dc --- /dev/null +++ b/audio_to_text/captioning/utils/predict_nn.py @@ -0,0 +1,49 @@ +import json +import random +import argparse +import numpy as np +from tqdm import tqdm +from h5py import File +import sklearn.metrics + +random.seed(1) + +parser = argparse.ArgumentParser() +parser.add_argument("train_feature", type=str) +parser.add_argument("train_corpus", type=str) +parser.add_argument("pred_feature", type=str) +parser.add_argument("output_json", type=str) + +args = parser.parse_args() +train_embs = [] +train_idx_to_audioid = [] +with File(args.train_feature, "r") as store: + for audio_id, embedding in tqdm(store.items(), ascii=True): + train_embs.append(embedding[()]) + train_idx_to_audioid.append(audio_id) + +train_annotation = json.load(open(args.train_corpus, "r"))["audios"] +train_audioid_to_tokens = {} +for item in train_annotation: + audio_id = item["audio_id"] + train_audioid_to_tokens[audio_id] = [cap_item["tokens"] for cap_item in item["captions"]] +train_embs = np.stack(train_embs) + + +pred_data = [] +pred_embs = [] +pred_idx_to_audioids = [] +with File(args.pred_feature, "r") as store: + for audio_id, embedding in tqdm(store.items(), ascii=True): + pred_embs.append(embedding[()]) + pred_idx_to_audioids.append(audio_id) +pred_embs = np.stack(pred_embs) + +similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs) +for idx, audio_id in enumerate(pred_idx_to_audioids): + train_idx = similarity[idx].argmax() + pred_data.append({ + "filename": audio_id, + "tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]]) + }) +json.dump({"predictions": pred_data}, open(args.output_json, "w"), ensure_ascii=False, 
indent=4) diff --git a/audio_to_text/captioning/utils/remove_optimizer.py b/audio_to_text/captioning/utils/remove_optimizer.py new file mode 100644 index 0000000..2b9871e --- /dev/null +++ b/audio_to_text/captioning/utils/remove_optimizer.py @@ -0,0 +1,18 @@ +import argparse +import torch + + +def main(checkpoint): + state_dict = torch.load(checkpoint, map_location="cpu") + if "optimizer" in state_dict: + del state_dict["optimizer"] + if "lr_scheduler" in state_dict: + del state_dict["lr_scheduler"] + torch.save(state_dict, checkpoint) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint", type=str) + args = parser.parse_args() + main(args.checkpoint) diff --git a/audio_to_text/captioning/utils/report_results.py b/audio_to_text/captioning/utils/report_results.py new file mode 100644 index 0000000..3b9f6ec --- /dev/null +++ b/audio_to_text/captioning/utils/report_results.py @@ -0,0 +1,37 @@ +from pathlib import Path +import argparse +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument("--input", help="input filename", type=str, nargs="+") +parser.add_argument("--output", help="output result file", default=None) + +args = parser.parse_args() + + +scores = {} +for path in args.input: + with open(path, "r") as reader: + for line in reader.readlines(): + metric, score = line.strip().split(": ") + score = float(score) + if metric not in scores: + scores[metric] = [] + scores[metric].append(score) + +if len(scores) == 0: + print("No experiment directory found, wrong path?") + exit(1) + +with open(args.output, "w") as writer: + print("Average results: ", file=writer) + for metric, score in scores.items(): + score = np.array(score) + mean = np.mean(score) + std = np.std(score) + print(f"{metric}: {mean:.3f} (±{std:.3f})", file=writer) + print("", file=writer) + print("Best results: ", file=writer) + for metric, score in scores.items(): + score = np.max(score) + print(f"{metric}: {score:.3f}", file=writer) diff --git a/audio_to_text/captioning/utils/tokenize_caption.py b/audio_to_text/captioning/utils/tokenize_caption.py new file mode 100644 index 0000000..b340068 --- /dev/null +++ b/audio_to_text/captioning/utils/tokenize_caption.py @@ -0,0 +1,86 @@ +import json +from tqdm import tqdm +import re +import fire + + +def tokenize_caption(input_json: str, + keep_punctuation: bool = False, + host_address: str = None, + character_level: bool = False, + zh: bool = True, + output_json: str = None): + """Build vocabulary from csv file with a given threshold to drop all counts < threshold + + Args: + input_json(string): Preprossessed json file. Structure like this: + { + 'audios': [ + { + 'audio_id': 'xxx', + 'captions': [ + { + 'caption': 'xxx', + 'cap_id': 'xxx' + } + ] + }, + ... + ] + } + threshold (int): Threshold to drop all words with counts < threshold + keep_punctuation (bool): Includes or excludes punctuation. 
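+        host_address (string): Address of the CoreNLP tokenization server, only used when zh=True
+        character_level (bool): Tokenize Chinese captions into characters instead of words
+        zh (bool): Whether the captions are in Chinese
+        output_json (string): If given, the tokenized json is written here instead of overwriting input_json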
+ + Returns: + vocab (Vocab): Object with the processed vocabulary +""" + data = json.load(open(input_json, "r"))["audios"] + + if zh: + from nltk.parse.corenlp import CoreNLPParser + from zhon.hanzi import punctuation + parser = CoreNLPParser(host_address) + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + for cap_idx in range(len(data[audio_idx]["captions"])): + caption = data[audio_idx]["captions"][cap_idx]["caption"] + # Remove all punctuations + if not keep_punctuation: + caption = re.sub("[{}]".format(punctuation), "", caption) + if character_level: + tokens = list(caption) + else: + tokens = list(parser.tokenize(caption)) + data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens) + else: + from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer + captions = {} + for audio_idx in range(len(data)): + audio_id = data[audio_idx]["audio_id"] + captions[audio_id] = [] + for cap_idx in range(len(data[audio_idx]["captions"])): + caption = data[audio_idx]["captions"][cap_idx]["caption"] + captions[audio_id].append({ + "audio_id": audio_id, + "id": cap_idx, + "caption": caption + }) + tokenizer = PTBTokenizer() + captions = tokenizer.tokenize(captions) + for audio_idx in tqdm(range(len(data)), leave=False, ascii=True): + audio_id = data[audio_idx]["audio_id"] + for cap_idx in range(len(data[audio_idx]["captions"])): + tokens = captions[audio_id][cap_idx] + data[audio_idx]["captions"][cap_idx]["tokens"] = tokens + + if output_json: + json.dump( + { "audios": data }, open(output_json, "w"), + indent=4, ensure_ascii=not zh) + else: + json.dump( + { "audios": data }, open(input_json, "w"), + indent=4, ensure_ascii=not zh) + + +if __name__ == "__main__": + fire.Fire(tokenize_caption) diff --git a/audio_to_text/captioning/utils/train_util.py b/audio_to_text/captioning/utils/train_util.py new file mode 100644 index 0000000..6cd62cc --- /dev/null +++ b/audio_to_text/captioning/utils/train_util.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +#!/usr/bin/env python3 +import os +import sys +import logging +from typing import Callable, Dict, Union +import yaml +import torch +from torch.optim.swa_utils import AveragedModel as torch_average_model +import numpy as np +import pandas as pd +from pprint import pformat + + +def load_dict_from_csv(csv, cols): + df = pd.read_csv(csv, sep="\t") + output = dict(zip(df[cols[0]], df[cols[1]])) + return output + + +def init_logger(filename, level="INFO"): + formatter = logging.Formatter( + "[ %(levelname)s : %(asctime)s ] - %(message)s") + logger = logging.getLogger(__name__ + "." + filename) + logger.setLevel(getattr(logging, level)) + # Log results to std + # stdhandler = logging.StreamHandler(sys.stdout) + # stdhandler.setFormatter(formatter) + # Dump log to file + filehandler = logging.FileHandler(filename) + filehandler.setFormatter(formatter) + logger.addHandler(filehandler) + # logger.addHandler(stdhandler) + return logger + + +def init_obj(module, config, **kwargs):# 'captioning.models.encoder' + obj_args = config["args"].copy() + obj_args.update(kwargs) + return getattr(module, config["type"])(**obj_args) + + +def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'): + """pprint_dict + + :param outputfun: function to use, defaults to sys.stdout + :param in_dict: dict to print + """ + if formatter == 'yaml': + format_fun = yaml.dump + elif formatter == 'pretty': + format_fun = pformat + for line in format_fun(in_dict).split('\n'): + outputfun(line) + + +def merge_a_into_b(a, b): + # merge dict a into dict b. 
values in a will overwrite b. + for k, v in a.items(): + if isinstance(v, dict) and k in b: + assert isinstance( + b[k], dict + ), "Cannot inherit key '{}' from base!".format(k) + merge_a_into_b(v, b[k]) + else: + b[k] = v + + +def load_config(config_file): + with open(config_file, "r") as reader: + config = yaml.load(reader, Loader=yaml.FullLoader) + if "inherit_from" in config: + base_config_file = config["inherit_from"] + base_config_file = os.path.join( + os.path.dirname(config_file), base_config_file + ) + assert not os.path.samefile(config_file, base_config_file), \ + "inherit from itself" + base_config = load_config(base_config_file) + del config["inherit_from"] + merge_a_into_b(config, base_config) + return base_config + return config + + +def parse_config_or_kwargs(config_file, **kwargs): + yaml_config = load_config(config_file) + # passed kwargs will override yaml config + args = dict(yaml_config, **kwargs) + return args + + +def store_yaml(config, config_file): + with open(config_file, "w") as con_writer: + yaml.dump(config, con_writer, indent=4, default_flow_style=False) + + +class MetricImprover: + + def __init__(self, mode): + assert mode in ("min", "max") + self.mode = mode + # min: lower -> better; max: higher -> better + self.best_value = np.inf if mode == "min" else -np.inf + + def compare(self, x, best_x): + return x < best_x if self.mode == "min" else x > best_x + + def __call__(self, x): + if self.compare(x, self.best_value): + self.best_value = x + return True + return False + + def state_dict(self): + return self.__dict__ + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + +def fix_batchnorm(model: torch.nn.Module): + def inner(module): + class_name = module.__class__.__name__ + if class_name.find("BatchNorm") != -1: + module.eval() + model.apply(inner) + + +def load_pretrained_model(model: torch.nn.Module, + pretrained: Union[str, Dict], + output_fn: Callable = sys.stdout.write): + if not isinstance(pretrained, dict) and not os.path.exists(pretrained): + output_fn(f"pretrained {pretrained} not exist!") + return + + if hasattr(model, "load_pretrained"): + model.load_pretrained(pretrained) + return + + if isinstance(pretrained, dict): + state_dict = pretrained + else: + state_dict = torch.load(pretrained, map_location="cpu") + + if "model" in state_dict: + state_dict = state_dict["model"] + model_dict = model.state_dict() + pretrained_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and ( + model_dict[k].shape == v.shape) + } + output_fn(f"Loading pretrained keys {pretrained_dict.keys()}") + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict, strict=True) + + +class AveragedModel(torch_average_model): + + def update_parameters(self, model): + for p_swa, p_model in zip(self.parameters(), model.parameters()): + device = p_swa.device + p_model_ = p_model.detach().to(device) + if self.n_averaged == 0: + p_swa.detach().copy_(p_model_) + else: + p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, + self.n_averaged.to(device))) + + for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()): + device = b_swa.device + b_model_ = b_model.detach().to(device) + if self.n_averaged == 0: + b_swa.detach().copy_(b_model_) + else: + b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_, + self.n_averaged.to(device))) + self.n_averaged += 1 diff --git a/audio_to_text/captioning/utils/word2vec/create_word_embedding.py b/audio_to_text/captioning/utils/word2vec/create_word_embedding.py new file mode 
100644 index 0000000..77ebe5a --- /dev/null +++ b/audio_to_text/captioning/utils/word2vec/create_word_embedding.py @@ -0,0 +1,67 @@ +# coding=utf-8 +#!/usr/bin/env python3 + +import numpy as np +import pandas as pd +import torch +import gensim +from gensim.models import Word2Vec +from tqdm import tqdm +import fire + +import sys +import os +sys.path.append(os.getcwd()) +from utils.build_vocab import Vocabulary + +def create_embedding(vocab_file: str, + embed_size: int, + output: str, + caption_file: str = None, + pretrained_weights_path: str = None, + **word2vec_kwargs): + vocabulary = torch.load(vocab_file, map_location="cpu") + + if pretrained_weights_path: + model = gensim.models.KeyedVectors.load_word2vec_format( + fname=pretrained_weights_path, + binary=True, + ) + if model.vector_size != embed_size: + assert embed_size < model.vector_size, f"only reduce dimension, cannot add dimesion {model.vector_size} to {embed_size}" + from sklearn.decomposition import PCA + pca = PCA(n_components=embed_size) + model.vectors = pca.fit_transform(model.vectors) + else: + caption_df = pd.read_json(caption_file) + caption_df["tokens"] = caption_df["tokens"].apply(lambda x: [""] + [token for token in x] + [""]) + sentences = list(caption_df["tokens"].values) + epochs = word2vec_kwargs.get("epochs", 10) + if "epochs" in word2vec_kwargs: + del word2vec_kwargs["epochs"] + model = Word2Vec(size=embed_size, min_count=1, **word2vec_kwargs) + model.build_vocab(sentences=sentences) + model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs) + + word_embeddings = np.random.randn(len(vocabulary), embed_size) + + if isinstance(model, gensim.models.word2vec.Word2Vec): + model = model.wv + with tqdm(total=len(vocabulary), ascii=True) as pbar: + for word, idx in vocabulary.word2idx.items(): + try: + word_embeddings[idx] = model.get_vector(word) + except KeyError: + print(f"word {word} not found in word2vec model, it is random initialized!") + pbar.update() + + np.save(output, word_embeddings) + + print("Finish writing word2vec embeddings to " + output) + + +if __name__ == "__main__": + fire.Fire(create_embedding) + + + diff --git a/audio_to_text/inference_waveform.py b/audio_to_text/inference_waveform.py new file mode 100644 index 0000000..aba3961 --- /dev/null +++ b/audio_to_text/inference_waveform.py @@ -0,0 +1,102 @@ +import sys +import os +import librosa +import numpy as np +import torch +import audio_to_text.captioning.models +import audio_to_text.captioning.models.encoder +import audio_to_text.captioning.models.decoder +import audio_to_text.captioning.utils.train_util as train_util + + +def load_model(config, checkpoint): + ckpt = torch.load(checkpoint, "cpu") + encoder_cfg = config["model"]["encoder"] + encoder = train_util.init_obj( + audio_to_text.captioning.models.encoder, + encoder_cfg + ) + if "pretrained" in encoder_cfg: + pretrained = encoder_cfg["pretrained"] + train_util.load_pretrained_model(encoder, + pretrained, + sys.stdout.write) + decoder_cfg = config["model"]["decoder"] + if "vocab_size" not in decoder_cfg["args"]: + decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"]) + decoder = train_util.init_obj( + audio_to_text.captioning.models.decoder, + decoder_cfg + ) + if "word_embedding" in decoder_cfg: + decoder.load_word_embedding(**decoder_cfg["word_embedding"]) + if "pretrained" in decoder_cfg: + pretrained = decoder_cfg["pretrained"] + train_util.load_pretrained_model(decoder, + pretrained, + sys.stdout.write) + model = 
train_util.init_obj(audio_to_text.captioning.models, config["model"], + encoder=encoder, decoder=decoder) + train_util.load_pretrained_model(model, ckpt) + model.eval() + return { + "model": model, + "vocabulary": ckpt["vocabulary"] + } + + +def decode_caption(word_ids, vocabulary): + candidate = [] + for word_id in word_ids: + word = vocabulary[word_id] + if word == "": + break + elif word == "": + continue + candidate.append(word) + candidate = " ".join(candidate) + return candidate + + +class AudioCapModel(object): + def __init__(self,weight_dir,device='cuda'): + config = os.path.join(weight_dir,'config.yaml') + self.config = train_util.parse_config_or_kwargs(config) + checkpoint = os.path.join(weight_dir,'swa.pth') + resumed = load_model(self.config, checkpoint) + model = resumed["model"] + self.vocabulary = resumed["vocabulary"] + self.model = model.to(device) + self.device = device + + def caption(self,audio_list): + if isinstance(audio_list,np.ndarray): + audio_list = [audio_list] + elif isinstance(audio_list,str): + audio_list = [librosa.load(audio_list,sr=32000)[0]] + + captions = [] + for wav in audio_list: + inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device) + wav_len = torch.LongTensor([len(wav)]) + input_dict = { + "mode": "inference", + "wav": inputwav, + "wav_len": wav_len, + "specaug": False, + "sample_method": "beam", + } + print(input_dict) + out_dict = self.model(input_dict) + caption_batch = [decode_caption(seq, self.vocabulary) for seq in \ + out_dict["seq"].cpu().numpy()] + captions.extend(caption_batch) + return captions + + + + def __call__(self, audio_list): + return self.caption(audio_list) + + + diff --git a/download.sh b/download.sh index 8de681a..e6be9cc 100644 --- a/download.sh +++ b/download.sh @@ -26,4 +26,9 @@ wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/Gene wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/spk_map.json wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/train_f0s_mean_std.npy wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/word_set.json - +wget -P text_to_speech/checkpoints/hifi_lj -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/text_to_speech/checkpoints/hifi_lj/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/text_to_speech/checkpoints/hifi_lj/model_ckpt_steps_2076000.ckpt +wget -P text_to_speech/checkpoints/ljspeech/ps_adv_baseline -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/text_to_speech/checkpoints/ljspeech/ps_adv_baseline/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160000.ckpt https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160001.ckpt +# Audio to text +wget -P audio_to_text/audiocaps_cntrstv_cnn14rnn_trm -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth +wget -P audio_to_text/clotho_cntrstv_cnn14rnn_trm -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth 
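# A quick sanity check that could follow these downloads (a sketch, not part of the
# original script): the wrapper in audio_to_text/inference_waveform.py expects a
# config.yaml next to swa.pth in each checkpoint directory.
for ckpt_dir in audio_to_text/audiocaps_cntrstv_cnn14rnn_trm audio_to_text/clotho_cntrstv_cnn14rnn_trm; do
    [ -f "$ckpt_dir/config.yaml" ] && [ -f "$ckpt_dir/swa.pth" ] || echo "missing checkpoint files in $ckpt_dir"
done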
+wget -P audio_to_text/pretrained_feature_extractors https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth \ No newline at end of file
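A minimal inference sketch for the captioning model whose checkpoints are fetched above, assuming download.sh has been run from the repository root; the wav path is a hypothetical placeholder.

from audio_to_text.inference_waveform import AudioCapModel

# Loads config.yaml and swa.pth from the given checkpoint directory;
# pass device="cpu" when no GPU is available.
captioner = AudioCapModel("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm", device="cpu")

# Accepts a wav file path, a single waveform array, or a list of arrays,
# and returns one caption string per input.
print(captioner("path/to/example.wav"))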