diff --git a/assets/2bf90e35.wav b/assets/2bf90e35.wav
new file mode 100644
index 0000000..1230dad
Binary files /dev/null and b/assets/2bf90e35.wav differ
diff --git a/assets/5d67d1b9.wav b/assets/5d67d1b9.wav
new file mode 100644
index 0000000..46beb7e
Binary files /dev/null and b/assets/5d67d1b9.wav differ
diff --git a/assets/README.md b/assets/README.md
index 7c83ea3..1cf23e0 100644
--- a/assets/README.md
+++ b/assets/README.md
@@ -7,21 +7,45 @@ Output:
Input Example : Generate an audio of a piano playing
Output:

+Audio:
+
+
+## Text-To-Speech
+Input Example : Generate a speech with text "here we go"
+Output:
+
+Audio:
+
+
## Text-To-Sing
Input example : please generate a piece of singing voice. Text sequence is 小酒窝长睫毛AP是你最美的记号. Note sequence is C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4. Note duration sequence is 0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340.
Output:

+Audio:
+
## Image-To-Audio
First upload your image(.png)
Input Example : Generate the audio of this image
Output:

-## ASR
+Audio:
+
+
+## Speech Recognition
First upload your audio(.wav)
-Input Example : Generate the text of this audio
+Audio Example :
+
+Input Example : Generate the text of this speech
Output:

+## Audio-To-Text
+First upload your audio(.wav)
+Audio Example :
+
+Input Example : Please tell me the text description of this audio.
+Output:
+
## Style Transfer Text-To-Speech
First upload your audio(.wav)
Input Example : Speak using the voice of this audio. The text is "here we go".
diff --git a/assets/Track 4.wav b/assets/Track 4.wav
new file mode 100644
index 0000000..64a19a1
Binary files /dev/null and b/assets/Track 4.wav differ
diff --git a/assets/a-group-of-sheep-are-baaing.wav b/assets/a-group-of-sheep-are-baaing.wav
new file mode 100644
index 0000000..2c5300a
Binary files /dev/null and b/assets/a-group-of-sheep-are-baaing.wav differ
diff --git a/assets/a2i.png b/assets/a2i.png
new file mode 100644
index 0000000..cbca66b
Binary files /dev/null and b/assets/a2i.png differ
diff --git a/assets/b973e878.wav b/assets/b973e878.wav
new file mode 100644
index 0000000..8abd02d
Binary files /dev/null and b/assets/b973e878.wav differ
diff --git a/assets/fd5cf55e.wav b/assets/fd5cf55e.wav
new file mode 100644
index 0000000..bbc6c76
Binary files /dev/null and b/assets/fd5cf55e.wav differ
diff --git a/assets/tts.png b/assets/tts.png
new file mode 100644
index 0000000..871f27a
Binary files /dev/null and b/assets/tts.png differ
diff --git a/audio_to_text/__init__.py b/audio_to_text/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/audio_to_text/captioning/__init__.py b/audio_to_text/captioning/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/audio_to_text/captioning/models/__init__.py b/audio_to_text/captioning/models/__init__.py
new file mode 100644
index 0000000..7259d67
--- /dev/null
+++ b/audio_to_text/captioning/models/__init__.py
@@ -0,0 +1,3 @@
+from .base_model import *
+from .transformer_model import *
+
diff --git a/audio_to_text/captioning/models/base_model.py b/audio_to_text/captioning/models/base_model.py
new file mode 100644
index 0000000..cd014e9
--- /dev/null
+++ b/audio_to_text/captioning/models/base_model.py
@@ -0,0 +1,500 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from .utils import mean_with_lens, repeat_tensor
+
+
+class CaptionModel(nn.Module):
+ """
+ Encoder-decoder captioning model.
+ """
+
+ pad_idx = 0
+ start_idx = 1
+ end_idx = 2
+ max_length = 20
+
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.vocab_size = decoder.vocab_size
+ self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
+ self.inference_forward_keys = ["sample_method", "max_length", "temp"]
+ freeze_encoder = kwargs.get("freeze_encoder", False)
+ if freeze_encoder:
+ for param in self.encoder.parameters():
+ param.requires_grad = False
+ self.check_decoder_compatibility()
+
+ def check_decoder_compatibility(self):
+ compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
+ assert isinstance(self.decoder, self.compatible_decoders), \
+ f"{self.decoder.__class__.__name__} is incompatible with " \
+ f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
+
+ @classmethod
+ def set_index(cls, start_idx, end_idx):
+ cls.start_idx = start_idx
+ cls.end_idx = end_idx
+
+ def forward(self, input_dict: Dict):
+ """
+ input_dict: {
+ (required)
+ mode: train/inference,
+ spec,
+ spec_len,
+ fc,
+ attn,
+ attn_len,
+ [sample_method: greedy],
+ [temp: 1.0] (in case of no teacher forcing)
+
+ (optional, mode=train)
+ cap,
+ cap_len,
+ ss_ratio,
+
+ (optional, mode=inference)
+ sample_method: greedy/beam,
+ max_length,
+ temp,
+ beam_size (optional, sample_method=beam),
+ n_best (optional, sample_method=beam),
+ }
+ """
+ # encoder_input_keys = ["spec", "spec_len", "fc", "attn", "attn_len"]
+ # encoder_input = { key: input_dict[key] for key in encoder_input_keys }
+ encoder_output_dict = self.encoder(input_dict)
+ if input_dict["mode"] == "train":
+ forward_dict = {
+ "mode": "train", "sample_method": "greedy", "temp": 1.0
+ }
+ for key in self.train_forward_keys:
+ forward_dict[key] = input_dict[key]
+ forward_dict.update(encoder_output_dict)
+ output = self.train_forward(forward_dict)
+ elif input_dict["mode"] == "inference":
+ forward_dict = {"mode": "inference"}
+ default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
+ for key in self.inference_forward_keys:
+ if key in input_dict:
+ forward_dict[key] = input_dict[key]
+ else:
+ forward_dict[key] = default_args[key]
+
+ if forward_dict["sample_method"] == "beam":
+ forward_dict["beam_size"] = input_dict.get("beam_size", 3)
+ forward_dict["n_best"] = input_dict.get("n_best", False)
+ forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
+ elif forward_dict["sample_method"] == "dbs":
+ forward_dict["beam_size"] = input_dict.get("beam_size", 6)
+ forward_dict["group_size"] = input_dict.get("group_size", 3)
+ forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
+ forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
+
+ forward_dict.update(encoder_output_dict)
+ output = self.inference_forward(forward_dict)
+ else:
+ raise Exception("mode should be either 'train' or 'inference'")
+
+ return output
+
+ def prepare_output(self, input_dict):
+ output = {}
+ batch_size = input_dict["fc_emb"].size(0)
+ if input_dict["mode"] == "train":
+ max_length = input_dict["cap"].size(1) - 1
+ elif input_dict["mode"] == "inference":
+ max_length = input_dict["max_length"]
+ else:
+ raise Exception("mode should be either 'train' or 'inference'")
+ device = input_dict["fc_emb"].device
+ output["seq"] = torch.full((batch_size, max_length), self.end_idx,
+ dtype=torch.long)
+ output["logit"] = torch.empty(batch_size, max_length,
+ self.vocab_size).to(device)
+ output["sampled_logprob"] = torch.zeros(batch_size, max_length)
+ output["embed"] = torch.empty(batch_size, max_length,
+ self.decoder.d_model).to(device)
+ return output
+
+ def train_forward(self, input_dict):
+ if input_dict["ss_ratio"] != 1: # scheduled sampling training
+ input_dict["mode"] = "train"
+ return self.stepwise_forward(input_dict)
+ output = self.seq_forward(input_dict)
+ self.train_process(output, input_dict)
+ return output
+
+ def seq_forward(self, input_dict):
+ raise NotImplementedError
+
+ def train_process(self, output, input_dict):
+ pass
+
+ def inference_forward(self, input_dict):
+ if input_dict["sample_method"] == "beam":
+ return self.beam_search(input_dict)
+ elif input_dict["sample_method"] == "dbs":
+ return self.diverse_beam_search(input_dict)
+ return self.stepwise_forward(input_dict)
+
+ def stepwise_forward(self, input_dict):
+ """Step-by-step decoding"""
+ output = self.prepare_output(input_dict)
+ max_length = output["seq"].size(1)
+ # start sampling
+ for t in range(max_length):
+ input_dict["t"] = t
+ self.decode_step(input_dict, output)
+ if input_dict["mode"] == "inference": # decide whether to stop when sampling
+ unfinished_t = output["seq"][:, t] != self.end_idx
+ if t == 0:
+ unfinished = unfinished_t
+ else:
+ unfinished *= unfinished_t
+ output["seq"][:, t][~unfinished] = self.end_idx
+ if unfinished.sum() == 0:
+ break
+ self.stepwise_process(output)
+ return output
+
+ def decode_step(self, input_dict, output):
+ """Decoding operation of timestep t"""
+ decoder_input = self.prepare_decoder_input(input_dict, output)
+ # feed to the decoder to get logit
+ output_t = self.decoder(decoder_input)
+ logit_t = output_t["logit"]
+ # assert logit_t.ndim == 3
+ if logit_t.size(1) == 1:
+ logit_t = logit_t.squeeze(1)
+ embed_t = output_t["embed"].squeeze(1)
+ elif logit_t.size(1) > 1:
+ logit_t = logit_t[:, -1, :]
+ embed_t = output_t["embed"][:, -1, :]
+ else:
+ raise Exception("no logit output")
+ # sample the next input word and get the corresponding logit
+ sampled = self.sample_next_word(logit_t,
+ method=input_dict["sample_method"],
+ temp=input_dict["temp"])
+
+ output_t.update(sampled)
+ output_t["t"] = input_dict["t"]
+ output_t["logit"] = logit_t
+ output_t["embed"] = embed_t
+ self.stepwise_process_step(output, output_t)
+
+ def prepare_decoder_input(self, input_dict, output):
+ """Prepare the inp ut dict for the decoder"""
+ raise NotImplementedError
+
+ def stepwise_process_step(self, output, output_t):
+ """Postprocessing (save output values) after each timestep t"""
+ t = output_t["t"]
+ output["logit"][:, t, :] = output_t["logit"]
+ output["seq"][:, t] = output_t["word"]
+ output["sampled_logprob"][:, t] = output_t["probs"]
+ output["embed"][:, t, :] = output_t["embed"]
+
+ def stepwise_process(self, output):
+ """Postprocessing after the whole step-by-step autoregressive decoding"""
+ pass
+
+ def sample_next_word(self, logit, method, temp):
+ """Sample the next word, given probs output by the decoder"""
+ logprob = torch.log_softmax(logit, dim=1)
+ if method == "greedy":
+ sampled_logprob, word = torch.max(logprob.detach(), 1)
+ elif method == "gumbel":
+ def sample_gumbel(shape, eps=1e-20):
+ U = torch.rand(shape).to(logprob.device)
+ return -torch.log(-torch.log(U + eps) + eps)
+ def gumbel_softmax_sample(logit, temperature):
+ y = logit + sample_gumbel(logit.size())
+ return torch.log_softmax(y / temperature, dim=-1)
+ _logprob = gumbel_softmax_sample(logprob, temp)
+ _, word = torch.max(_logprob.data, 1)
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
+ else:
+ logprob = logprob / temp
+ if method.startswith("top"):
+ top_num = float(method[3:])
+ if 0 < top_num < 1: # top-p sampling
+ probs = torch.softmax(logit, dim=1)
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
+ _cumsum = sorted_probs.cumsum(1)
+ mask = _cumsum < top_num
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
+ logprob.scatter_(1, sorted_indices, sorted_probs.log())
+ else: # top-k sampling
+ k = int(top_num)
+ tmp = torch.empty_like(logprob).fill_(float('-inf'))
+ topk, indices = torch.topk(logprob, k, dim=1)
+ tmp = tmp.scatter(1, indices, topk)
+ logprob = tmp
+ word = torch.distributions.Categorical(logits=logprob.detach()).sample()
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
+ word = word.detach().long()
+ # sampled_logprob: [N,], word: [N,]
+ return {"word": word, "probs": sampled_logprob}
+
+ def beam_search(self, input_dict):
+ output = self.prepare_output(input_dict)
+ max_length = input_dict["max_length"]
+ beam_size = input_dict["beam_size"]
+ if input_dict["n_best"]:
+ n_best_size = input_dict["n_best_size"]
+ batch_size, max_length = output["seq"].size()
+ output["seq"] = torch.full((batch_size, n_best_size, max_length),
+ self.end_idx, dtype=torch.long)
+
+ temp = input_dict["temp"]
+ # instance by instance beam seach
+ for i in range(output["seq"].size(0)):
+ output_i = self.prepare_beamsearch_output(input_dict)
+ input_dict["sample_idx"] = i
+ for t in range(max_length):
+ input_dict["t"] = t
+ output_t = self.beamsearch_step(input_dict, output_i)
+ #######################################
+ # merge with previous beam and select the current max prob beam
+ #######################################
+ logit_t = output_t["logit"]
+ if logit_t.size(1) == 1:
+ logit_t = logit_t.squeeze(1)
+ elif logit_t.size(1) > 1:
+ logit_t = logit_t[:, -1, :]
+ else:
+ raise Exception("no logit output")
+ logprob_t = torch.log_softmax(logit_t, dim=1)
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
+ logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
+ if t == 0: # for the first step, all k seq will have the same probs
+ topk_logprob, topk_words = logprob_t[0].topk(
+ beam_size, 0, True, True)
+ else: # unroll and find top logprob, and their unrolled indices
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
+ beam_size, 0, True, True)
+ topk_words = topk_words.cpu()
+ output_i["topk_logprob"] = topk_logprob
+ # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
+ output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
+ rounding_mode='trunc')
+ output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
+ if t == 0:
+ output_i["seq"] = output_i["next_word"].unsqueeze(1)
+ else:
+ output_i["seq"] = torch.cat([
+ output_i["seq"][output_i["prev_words_beam"]],
+ output_i["next_word"].unsqueeze(1)], dim=1)
+
+ # add finished beams to results
+ is_end = output_i["next_word"] == self.end_idx
+ if t == max_length - 1:
+ is_end.fill_(1)
+
+ for beam_idx in range(beam_size):
+ if is_end[beam_idx]:
+ final_beam = {
+ "seq": output_i["seq"][beam_idx].clone(),
+ "score": output_i["topk_logprob"][beam_idx].item()
+ }
+ final_beam["score"] = final_beam["score"] / (t + 1)
+ output_i["done_beams"].append(final_beam)
+ output_i["topk_logprob"][is_end] -= 1000
+
+ self.beamsearch_process_step(output_i, output_t)
+
+ self.beamsearch_process(output, output_i, input_dict)
+ return output
+
+ def prepare_beamsearch_output(self, input_dict):
+ beam_size = input_dict["beam_size"]
+ device = input_dict["fc_emb"].device
+ output = {
+ "topk_logprob": torch.zeros(beam_size).to(device),
+ "seq": None,
+ "prev_words_beam": None,
+ "next_word": None,
+ "done_beams": [],
+ }
+ return output
+
+ def beamsearch_step(self, input_dict, output_i):
+ decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
+ output_t = self.decoder(decoder_input)
+ output_t["t"] = input_dict["t"]
+ return output_t
+
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
+ raise NotImplementedError
+
+ def beamsearch_process_step(self, output_i, output_t):
+ pass
+
+ def beamsearch_process(self, output, output_i, input_dict):
+ i = input_dict["sample_idx"]
+ done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
+ if input_dict["n_best"]:
+ done_beams = done_beams[:input_dict["n_best_size"]]
+ for out_idx, done_beam in enumerate(done_beams):
+ seq = done_beam["seq"]
+ output["seq"][i][out_idx, :len(seq)] = seq
+ else:
+ seq = done_beams[0]["seq"]
+ output["seq"][i][:len(seq)] = seq
+
+ def diverse_beam_search(self, input_dict):
+
+ def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
+ local_time = t - divm
+ unaug_logprob = logprob.clone()
+
+ if divm > 0:
+ change = torch.zeros(logprob.size(-1))
+ for prev_choice in range(divm):
+ prev_decisions = seq_table[prev_choice][..., local_time]
+ for prev_labels in range(bdash):
+ change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
+
+ change = change.to(logprob.device)
+ logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
+
+ return logprob, unaug_logprob
+
+ output = self.prepare_output(input_dict)
+ group_size = input_dict["group_size"]
+ batch_size = output["seq"].size(0)
+ beam_size = input_dict["beam_size"]
+ bdash = beam_size // group_size
+ input_dict["bdash"] = bdash
+ diversity_lambda = input_dict["diversity_lambda"]
+ device = input_dict["fc_emb"].device
+ max_length = input_dict["max_length"]
+ temp = input_dict["temp"]
+ group_nbest = input_dict["group_nbest"]
+ batch_size, max_length = output["seq"].size()
+ if group_nbest:
+ output["seq"] = torch.full((batch_size, beam_size, max_length),
+ self.end_idx, dtype=torch.long)
+ else:
+ output["seq"] = torch.full((batch_size, group_size, max_length),
+ self.end_idx, dtype=torch.long)
+
+
+ for i in range(batch_size):
+ input_dict["sample_idx"] = i
+ seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
+ logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
+ done_beams_table = [[] for _ in range(group_size)]
+
+ output_i = {
+ "prev_words_beam": [None for _ in range(group_size)],
+ "next_word": [None for _ in range(group_size)],
+ "state": [None for _ in range(group_size)]
+ }
+
+ for t in range(max_length + group_size - 1):
+ input_dict["t"] = t
+ for divm in range(group_size):
+ input_dict["divm"] = divm
+ if t >= divm and t <= max_length + divm - 1:
+ local_time = t - divm
+ decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
+ output_t = self.decoder(decoder_input)
+ output_t["divm"] = divm
+ logit_t = output_t["logit"]
+ if logit_t.size(1) == 1:
+ logit_t = logit_t.squeeze(1)
+ elif logit_t.size(1) > 1:
+ logit_t = logit_t[:, -1, :]
+ else:
+ raise Exception("no logit output")
+ logprob_t = torch.log_softmax(logit_t, dim=1)
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
+ logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
+ logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
+ if local_time == 0: # for the first step, all k seq will have the same probs
+ topk_logprob, topk_words = logprob_t[0].topk(
+ bdash, 0, True, True)
+ else: # unroll and find top logprob, and their unrolled indices
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
+ bdash, 0, True, True)
+ topk_words = topk_words.cpu()
+ logprob_table[divm] = topk_logprob
+ output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
+ output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
+ if local_time > 0:
+ seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
+ seq_table[divm] = torch.cat([
+ seq_table[divm],
+ output_i["next_word"][divm].unsqueeze(-1)], -1)
+
+ is_end = seq_table[divm][:, t-divm] == self.end_idx
+ assert seq_table[divm].shape[-1] == t - divm + 1
+ if t == max_length + divm - 1:
+ is_end.fill_(1)
+ for beam_idx in range(bdash):
+ if is_end[beam_idx]:
+ final_beam = {
+ "seq": seq_table[divm][beam_idx].clone(),
+ "score": logprob_table[divm][beam_idx].item()
+ }
+ final_beam["score"] = final_beam["score"] / (t - divm + 1)
+ done_beams_table[divm].append(final_beam)
+ logprob_table[divm][is_end] -= 1000
+ self.dbs_process_step(output_i, output_t)
+ done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
+ if group_nbest:
+ done_beams = sum(done_beams_table, [])
+ else:
+ done_beams = [group_beam[0] for group_beam in done_beams_table]
+ for _, done_beam in enumerate(done_beams):
+ output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
+
+ return output
+
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
+ raise NotImplementedError
+
+ def dbs_process_step(self, output_i, output_t):
+ pass
+
+
+class CaptionSequenceModel(nn.Module):
+
+ def __init__(self, model, seq_output_size):
+ super().__init__()
+ self.model = model
+ if model.decoder.d_model != seq_output_size:
+ self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
+ else:
+ self.output_transform = lambda x: x
+
+ def forward(self, input_dict):
+ output = self.model(input_dict)
+
+ if input_dict["mode"] == "train":
+ lens = input_dict["cap_len"] - 1
+ # seq_outputs: [N, d_model]
+ elif input_dict["mode"] == "inference":
+ if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
+ return output
+ seq = output["seq"]
+ lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
+ else:
+ raise Exception("mode should be either 'train' or 'inference'")
+ seq_output = mean_with_lens(output["embed"], lens)
+ seq_output = self.output_transform(seq_output)
+ output["seq_output"] = seq_output
+ return output
+
diff --git a/audio_to_text/captioning/models/decoder.py b/audio_to_text/captioning/models/decoder.py
new file mode 100644
index 0000000..869eac1
--- /dev/null
+++ b/audio_to_text/captioning/models/decoder.py
@@ -0,0 +1,746 @@
+# -*- coding: utf-8 -*-
+
+import math
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .utils import generate_length_mask, init, PositionalEncoding
+
+
+class BaseDecoder(nn.Module):
+ """
+ Take word/audio embeddings and output the next word probs
+ Base decoder, cannot be called directly
+ All decoders should inherit from this class
+ """
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim,
+ attn_emb_dim, dropout=0.2):
+ super().__init__()
+ self.emb_dim = emb_dim
+ self.vocab_size = vocab_size
+ self.fc_emb_dim = fc_emb_dim
+ self.attn_emb_dim = attn_emb_dim
+ self.word_embedding = nn.Embedding(vocab_size, emb_dim)
+ self.in_dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ raise NotImplementedError
+
+ def load_word_embedding(self, weight, freeze=True):
+ embedding = np.load(weight)
+ assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
+ assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
+
+ # embeddings = torch.as_tensor(embeddings).float()
+ # self.word_embeddings.weight = nn.Parameter(embeddings)
+ # for para in self.word_embeddings.parameters():
+ # para.requires_grad = tune
+ self.word_embedding = nn.Embedding.from_pretrained(embedding,
+ freeze=freeze)
+
+
+class RnnDecoder(BaseDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs):
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout,)
+ self.d_model = d_model
+ self.num_layers = kwargs.get('num_layers', 1)
+ self.bidirectional = kwargs.get('bidirectional', False)
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
+ self.classifier = nn.Linear(
+ self.d_model * (self.bidirectional + 1), vocab_size)
+
+ def forward(self, x):
+ raise NotImplementedError
+
+ def init_hidden(self, bs, device):
+ num_dire = self.bidirectional + 1
+ n_layer = self.num_layers
+ hid_dim = self.d_model
+ if self.rnn_type == "LSTM":
+ return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
+ torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
+ else:
+ return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)
+
+
+class RnnFcDecoder(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs):
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim * 2,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None)
+ fc_emb = input_dict["fc_emb"]
+
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+
+ p_fc_emb = self.fc_proj(fc_emb)
+ # embed: [N, T, embed_size]
+ embed = torch.cat((embed, p_fc_emb), dim=-1)
+
+ out, state = self.model(embed, state)
+ # out: [N, T, hs], states: [num_layers * num_dire, N, hs]
+ logits = self.classifier(out)
+ output = {
+ "state": state,
+ "embeds": out,
+ "logits": logits
+ }
+
+ return output
+
+
+class Seq2SeqAttention(nn.Module):
+
+ def __init__(self, hs_enc, hs_dec, attn_size):
+ """
+ Args:
+ hs_enc: encoder hidden size
+ hs_dec: decoder hidden size
+ attn_size: attention vector size
+ """
+ super(Seq2SeqAttention, self).__init__()
+ self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
+ self.v = nn.Parameter(torch.randn(attn_size))
+ self.apply(init)
+
+ def forward(self, h_dec, h_enc, src_lens):
+ """
+ Args:
+ h_dec: decoder hidden (query), [N, hs_dec]
+ h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
+ src_lens: source (encoder memory) lengths, [N, ]
+ """
+ N = h_enc.size(0)
+ src_max_len = h_enc.size(1)
+ h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
+
+ attn_input = torch.cat((h_dec, h_enc), dim=-1)
+ attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
+
+ v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
+ score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
+
+ idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
+ mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
+
+ score = score.masked_fill(mask == 0, -1e10)
+ weights = torch.softmax(score, dim=-1) # [N, src_max_len]
+ ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
+
+ return ctx, weights
+
+
+class AttentionProj(nn.Module):
+
+ def __init__(self, hs_enc, hs_dec, embed_dim, attn_size):
+ self.q_proj = nn.Linear(hs_dec, embed_dim)
+ self.kv_proj = nn.Linear(hs_enc, embed_dim)
+ self.h2attn = nn.Linear(embed_dim * 2, attn_size)
+ self.v = nn.Parameter(torch.randn(attn_size))
+ self.apply(init)
+
+ def init(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.kaiming_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, h_dec, h_enc, src_lens):
+ """
+ Args:
+ h_dec: decoder hidden (query), [N, hs_dec]
+ h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
+ src_lens: source (encoder memory) lengths, [N, ]
+ """
+ h_enc = self.kv_proj(h_enc) # [N, src_max_len, embed_dim]
+ h_dec = self.q_proj(h_dec) # [N, embed_dim]
+ N = h_enc.size(0)
+ src_max_len = h_enc.size(1)
+ h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
+
+ attn_input = torch.cat((h_dec, h_enc), dim=-1)
+ attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
+
+ v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
+ score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
+
+ idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
+ mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
+
+ score = score.masked_fill(mask == 0, -1e10)
+ weights = torch.softmax(score, dim=-1) # [N, src_max_len]
+ ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
+
+ return ctx, weights
+
+
+class BahAttnDecoder(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs):
+ """
+ concatenate fc, attn, word to feed to the rnn
+ """
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs)
+ attn_size = kwargs.get("attn_size", self.d_model)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim * 3,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
+ self.d_model * (self.bidirectional + 1) * \
+ self.num_layers,
+ attn_size)
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
+ fc_emb = input_dict["fc_emb"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+
+ # embed: [N, 1, embed_size]
+ if state is None:
+ state = self.init_hidden(word.size(0), fc_emb.device)
+ if self.rnn_type == "LSTM":
+ query = state[0].transpose(0, 1).flatten(1)
+ else:
+ query = state.transpose(0, 1).flatten(1)
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
+
+ p_fc_emb = self.fc_proj(fc_emb)
+ p_ctx = self.ctx_proj(c)
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
+ dim=-1)
+
+ out, state = self.model(rnn_input, state)
+
+ output = {
+ "state": state,
+ "embed": out,
+ "logit": self.classifier(out),
+ "attn_weight": attn_weight
+ }
+ return output
+
+
+class BahAttnDecoder2(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs):
+ """
+ add fc, attn, word together to feed to the rnn
+ """
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs)
+ attn_size = kwargs.get("attn_size", self.d_model)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.attn = Seq2SeqAttention(self.emb_dim,
+ self.d_model * (self.bidirectional + 1) * \
+ self.num_layers,
+ attn_size)
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
+ self.attn_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
+ self.apply(partial(init, method="xavier"))
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
+ fc_emb = input_dict["fc_emb"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+ p_attn_emb = self.attn_proj(attn_emb)
+
+ # embed: [N, 1, embed_size]
+ if state is None:
+ state = self.init_hidden(word.size(0), fc_emb.device)
+ if self.rnn_type == "LSTM":
+ query = state[0].transpose(0, 1).flatten(1)
+ else:
+ query = state.transpose(0, 1).flatten(1)
+ c, attn_weight = self.attn(query, p_attn_emb, attn_emb_len)
+
+ p_fc_emb = self.fc_proj(fc_emb)
+ rnn_input = embed + c.unsqueeze(1) + p_fc_emb.unsqueeze(1)
+
+ out, state = self.model(rnn_input, state)
+
+ output = {
+ "state": state,
+ "embed": out,
+ "logit": self.classifier(out),
+ "attn_weight": attn_weight
+ }
+ return output
+
+
+class ConditionalBahAttnDecoder(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs):
+ """
+ concatenate fc, attn, word to feed to the rnn
+ """
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs)
+ attn_size = kwargs.get("attn_size", self.d_model)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim * 3,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
+ self.d_model * (self.bidirectional + 1) * \
+ self.num_layers,
+ attn_size)
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
+ self.condition_embedding = nn.Embedding(2, emb_dim)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
+ fc_emb = input_dict["fc_emb"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+ condition = input_dict["condition"]
+
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+
+ condition = torch.as_tensor([[1 - c, c] for c in condition]).to(fc_emb.device)
+ condition_emb = torch.matmul(condition, self.condition_embedding.weight)
+ # condition_embs: [N, emb_dim]
+
+ # embed: [N, 1, embed_size]
+ if state is None:
+ state = self.init_hidden(word.size(0), fc_emb.device)
+ if self.rnn_type == "LSTM":
+ query = state[0].transpose(0, 1).flatten(1)
+ else:
+ query = state.transpose(0, 1).flatten(1)
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
+
+ p_ctx = self.ctx_proj(c)
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), condition_emb.unsqueeze(1)),
+ dim=-1)
+
+ out, state = self.model(rnn_input, state)
+
+ output = {
+ "state": state,
+ "embed": out,
+ "logit": self.classifier(out),
+ "attn_weight": attn_weight
+ }
+ return output
+
+
+class StructBahAttnDecoder(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, struct_vocab_size,
+ attn_emb_dim, dropout, d_model, **kwargs):
+ """
+ concatenate fc, attn, word to feed to the rnn
+ """
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs)
+ attn_size = kwargs.get("attn_size", self.d_model)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim * 3,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
+ self.d_model * (self.bidirectional + 1) * \
+ self.num_layers,
+ attn_size)
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
+ self.struct_embedding = nn.Embedding(struct_vocab_size, emb_dim)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
+ fc_emb = input_dict["fc_emb"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+ structure = input_dict["structure"]
+
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+
+ struct_emb = self.struct_embedding(structure)
+ # struct_embs: [N, emb_dim]
+
+ # embed: [N, 1, embed_size]
+ if state is None:
+ state = self.init_hidden(word.size(0), fc_emb.device)
+ if self.rnn_type == "LSTM":
+ query = state[0].transpose(0, 1).flatten(1)
+ else:
+ query = state.transpose(0, 1).flatten(1)
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
+
+ p_ctx = self.ctx_proj(c)
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), struct_emb.unsqueeze(1)), dim=-1)
+
+ out, state = self.model(rnn_input, state)
+
+ output = {
+ "state": state,
+ "embed": out,
+ "logit": self.classifier(out),
+ "attn_weight": attn_weight
+ }
+ return output
+
+
+class StyleBahAttnDecoder(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs):
+ """
+ concatenate fc, attn, word to feed to the rnn
+ """
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs)
+ attn_size = kwargs.get("attn_size", self.d_model)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim * 3,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
+ self.d_model * (self.bidirectional + 1) * \
+ self.num_layers,
+ attn_size)
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
+ fc_emb = input_dict["fc_emb"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+ style = input_dict["style"]
+
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+
+ # embed: [N, 1, embed_size]
+ if state is None:
+ state = self.init_hidden(word.size(0), fc_emb.device)
+ if self.rnn_type == "LSTM":
+ query = state[0].transpose(0, 1).flatten(1)
+ else:
+ query = state.transpose(0, 1).flatten(1)
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
+
+ p_ctx = self.ctx_proj(c)
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), style.unsqueeze(1)),
+ dim=-1)
+
+ out, state = self.model(rnn_input, state)
+
+ output = {
+ "state": state,
+ "embed": out,
+ "logit": self.classifier(out),
+ "attn_weight": attn_weight
+ }
+ return output
+
+
+class BahAttnDecoder3(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs):
+ """
+ concatenate fc, attn, word to feed to the rnn
+ """
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs)
+ attn_size = kwargs.get("attn_size", self.d_model)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim + attn_emb_dim,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
+ self.d_model * (self.bidirectional + 1) * \
+ self.num_layers,
+ attn_size)
+ self.ctx_proj = lambda x: x
+ self.apply(init)
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
+ fc_emb = input_dict["fc_emb"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+
+ if word.size(-1) == self.fc_emb_dim: # fc_emb
+ embed = word.unsqueeze(1)
+ elif word.size(-1) == 1: # word
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+ else:
+ raise Exception(f"problem with word input size {word.size()}")
+
+ # embed: [N, 1, embed_size]
+ if state is None:
+ state = self.init_hidden(word.size(0), fc_emb.device)
+ if self.rnn_type == "LSTM":
+ query = state[0].transpose(0, 1).flatten(1)
+ else:
+ query = state.transpose(0, 1).flatten(1)
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
+
+ p_ctx = self.ctx_proj(c)
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1)), dim=-1)
+
+ out, state = self.model(rnn_input, state)
+
+ output = {
+ "state": state,
+ "embed": out,
+ "logit": self.classifier(out),
+ "attn_weight": attn_weight
+ }
+ return output
+
+
+class SpecificityBahAttnDecoder(RnnDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs):
+ """
+ concatenate fc, attn, word to feed to the rnn
+ """
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, d_model, **kwargs)
+ attn_size = kwargs.get("attn_size", self.d_model)
+ self.model = getattr(nn, self.rnn_type)(
+ input_size=self.emb_dim + attn_emb_dim + 1,
+ hidden_size=self.d_model,
+ batch_first=True,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional)
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
+ self.d_model * (self.bidirectional + 1) * \
+ self.num_layers,
+ attn_size)
+ self.ctx_proj = lambda x: x
+ self.apply(init)
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
+ fc_emb = input_dict["fc_emb"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+ condition = input_dict["condition"] # [N,]
+
+ word = word.to(fc_emb.device)
+ embed = self.in_dropout(self.word_embedding(word))
+
+ # embed: [N, 1, embed_size]
+ if state is None:
+ state = self.init_hidden(word.size(0), fc_emb.device)
+ if self.rnn_type == "LSTM":
+ query = state[0].transpose(0, 1).flatten(1)
+ else:
+ query = state.transpose(0, 1).flatten(1)
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
+
+ p_ctx = self.ctx_proj(c)
+ rnn_input = torch.cat(
+ (embed, p_ctx.unsqueeze(1), condition.reshape(-1, 1, 1)),
+ dim=-1)
+
+ out, state = self.model(rnn_input, state)
+
+ output = {
+ "state": state,
+ "embed": out,
+ "logit": self.classifier(out),
+ "attn_weight": attn_weight
+ }
+ return output
+
+
+class TransformerDecoder(BaseDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, **kwargs):
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout=dropout,)
+ self.d_model = emb_dim
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
+ self.nlayers = kwargs.get("nlayers", 2)
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
+
+ self.pos_encoder = PositionalEncoding(self.d_model, dropout)
+ layer = nn.TransformerDecoderLayer(d_model=self.d_model,
+ nhead=self.nhead,
+ dim_feedforward=self.dim_feedforward,
+ dropout=dropout)
+ self.model = nn.TransformerDecoder(layer, self.nlayers)
+ self.classifier = nn.Linear(self.d_model, vocab_size)
+ self.attn_proj = nn.Sequential(
+ nn.Linear(self.attn_emb_dim, self.d_model),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.LayerNorm(self.d_model)
+ )
+ # self.attn_proj = lambda x: x
+ self.init_params()
+
+ def init_params(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def generate_square_subsequent_mask(self, max_length):
+ mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+ return mask
+
+ def forward(self, input_dict):
+ word = input_dict["word"]
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+ cap_padding_mask = input_dict["cap_padding_mask"]
+
+ p_attn_emb = self.attn_proj(attn_emb)
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
+ word = word.to(attn_emb.device)
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
+ embed = self.pos_encoder(embed)
+
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
+ tgt_key_padding_mask=cap_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask)
+ output = output.transpose(0, 1)
+ output = {
+ "embed": output,
+ "logit": self.classifier(output),
+ }
+ return output
+
+
+
+
+class EventTransformerDecoder(TransformerDecoder):
+
+ def forward(self, input_dict):
+ word = input_dict["word"] # index of word embeddings
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+ cap_padding_mask = input_dict["cap_padding_mask"]
+ event_emb = input_dict["event"] # [N, emb_dim]
+
+ p_attn_emb = self.attn_proj(attn_emb)
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
+ word = word.to(attn_emb.device)
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
+
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
+ embed += event_emb
+ embed = self.pos_encoder(embed)
+
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
+ tgt_key_padding_mask=cap_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask)
+ output = output.transpose(0, 1)
+ output = {
+ "embed": output,
+ "logit": self.classifier(output),
+ }
+ return output
+
+
+class KeywordProbTransformerDecoder(TransformerDecoder):
+
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, keyword_classes_num, **kwargs):
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
+ dropout, **kwargs)
+ self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
+ self.word_keyword_norm = nn.LayerNorm(self.d_model)
+
+ def forward(self, input_dict):
+ word = input_dict["word"] # index of word embeddings
+ attn_emb = input_dict["attn_emb"]
+ attn_emb_len = input_dict["attn_emb_len"]
+ cap_padding_mask = input_dict["cap_padding_mask"]
+ keyword = input_dict["keyword"] # [N, keyword_classes_num]
+
+ p_attn_emb = self.attn_proj(attn_emb)
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
+ word = word.to(attn_emb.device)
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
+
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
+ embed += self.keyword_proj(keyword)
+ embed = self.word_keyword_norm(embed)
+
+ embed = self.pos_encoder(embed)
+
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
+ tgt_key_padding_mask=cap_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask)
+ output = output.transpose(0, 1)
+ output = {
+ "embed": output,
+ "logit": self.classifier(output),
+ }
+ return output
diff --git a/audio_to_text/captioning/models/encoder.py b/audio_to_text/captioning/models/encoder.py
new file mode 100644
index 0000000..0d6d8e8
--- /dev/null
+++ b/audio_to_text/captioning/models/encoder.py
@@ -0,0 +1,686 @@
+# -*- coding: utf-8 -*-
+
+import math
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchaudio import transforms
+from torchlibrosa.augmentation import SpecAugmentation
+
+from .utils import mean_with_lens, max_with_lens, \
+ init, pack_wrapper, generate_length_mask, PositionalEncoding
+
+
+def init_layer(layer):
+ """Initialize a Linear or Convolutional layer. """
+ nn.init.xavier_uniform_(layer.weight)
+
+ if hasattr(layer, 'bias'):
+ if layer.bias is not None:
+ layer.bias.data.fill_(0.)
+
+
+def init_bn(bn):
+ """Initialize a Batchnorm layer. """
+ bn.bias.data.fill_(0.)
+ bn.weight.data.fill_(1.)
+
+
+class BaseEncoder(nn.Module):
+
+ """
+ Encode the given audio into embedding
+ Base encoder class, cannot be called directly
+ All encoders should inherit from this class
+ """
+
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
+ super(BaseEncoder, self).__init__()
+ self.spec_dim = spec_dim
+ self.fc_feat_dim = fc_feat_dim
+ self.attn_feat_dim = attn_feat_dim
+
+
+ def forward(self, x):
+ #########################
+ # an encoder first encodes audio feature into embedding, obtaining
+ # `encoded`: {
+ # fc_embs: [N, fc_emb_dim],
+ # attn_embs: [N, attn_max_len, attn_emb_dim],
+ # attn_emb_lens: [N,]
+ # }
+ #########################
+ raise NotImplementedError
+
+
+class Block2D(nn.Module):
+
+ def __init__(self, cin, cout, kernel_size=3, padding=1):
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.BatchNorm2d(cin),
+ nn.Conv2d(cin,
+ cout,
+ kernel_size=kernel_size,
+ padding=padding,
+ bias=False),
+ nn.LeakyReLU(inplace=True, negative_slope=0.1))
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class LinearSoftPool(nn.Module):
+ """LinearSoftPool
+ Linear softmax, takes logits and returns a probability, near to the actual maximum value.
+ Taken from the paper:
+ A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling
+ https://arxiv.org/abs/1810.09050
+ """
+ def __init__(self, pooldim=1):
+ super().__init__()
+ self.pooldim = pooldim
+
+ def forward(self, logits, time_decision):
+ return (time_decision**2).sum(self.pooldim) / time_decision.sum(
+ self.pooldim)
+
+
+class MeanPool(nn.Module):
+
+ def __init__(self, pooldim=1):
+ super().__init__()
+ self.pooldim = pooldim
+
+ def forward(self, logits, decision):
+ return torch.mean(decision, dim=self.pooldim)
+
+
+class AttentionPool(nn.Module):
+ """docstring for AttentionPool"""
+ def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
+ super().__init__()
+ self.inputdim = inputdim
+ self.outputdim = outputdim
+ self.pooldim = pooldim
+ self.transform = nn.Linear(inputdim, outputdim)
+ self.activ = nn.Softmax(dim=self.pooldim)
+ self.eps = 1e-7
+
+ def forward(self, logits, decision):
+ # Input is (B, T, D)
+ # B, T, D
+ w = self.activ(torch.clamp(self.transform(logits), -15, 15))
+ detect = (decision * w).sum(
+ self.pooldim) / (w.sum(self.pooldim) + self.eps)
+ # B, T, D
+ return detect
+
+
+class MMPool(nn.Module):
+
+ def __init__(self, dims):
+ super().__init__()
+ self.avgpool = nn.AvgPool2d(dims)
+ self.maxpool = nn.MaxPool2d(dims)
+
+ def forward(self, x):
+ return self.avgpool(x) + self.maxpool(x)
+
+
+def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
+ """parse_poolingfunction
+ A heler function to parse any temporal pooling
+ Pooling is done on dimension 1
+ :param poolingfunction_name:
+ :param **kwargs:
+ """
+ poolingfunction_name = poolingfunction_name.lower()
+ if poolingfunction_name == 'mean':
+ return MeanPool(pooldim=1)
+ elif poolingfunction_name == 'linear':
+ return LinearSoftPool(pooldim=1)
+ elif poolingfunction_name == 'attention':
+ return AttentionPool(inputdim=kwargs['inputdim'],
+ outputdim=kwargs['outputdim'])
+
+
+def embedding_pooling(x, lens, pooling="mean"):
+ if pooling == "max":
+ fc_embs = max_with_lens(x, lens)
+ elif pooling == "mean":
+ fc_embs = mean_with_lens(x, lens)
+ elif pooling == "mean+max":
+ x_mean = mean_with_lens(x, lens)
+ x_max = max_with_lens(x, lens)
+ fc_embs = x_mean + x_max
+ elif pooling == "last":
+ indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
+ # indices: [N, 1, hidden]
+ fc_embs = torch.gather(x, 1, indices).squeeze(1)
+ else:
+ raise Exception(f"pooling method {pooling} not support")
+ return fc_embs
+
+
+class Cdur5Encoder(BaseEncoder):
+
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
+ self.pooling = pooling
+ self.features = nn.Sequential(
+ Block2D(1, 32),
+ nn.LPPool2d(4, (2, 4)),
+ Block2D(32, 128),
+ Block2D(128, 128),
+ nn.LPPool2d(4, (2, 4)),
+ Block2D(128, 128),
+ Block2D(128, 128),
+ nn.LPPool2d(4, (1, 4)),
+ nn.Dropout(0.3),
+ )
+ with torch.no_grad():
+ rnn_input_dim = self.features(
+ torch.randn(1, 1, 500, spec_dim)).shape
+ rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
+
+ self.gru = nn.GRU(rnn_input_dim,
+ 128,
+ bidirectional=True,
+ batch_first=True)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ x = input_dict["spec"]
+ lens = input_dict["spec_len"]
+ if "upsample" not in input_dict:
+ input_dict["upsample"] = False
+ lens = torch.as_tensor(copy.deepcopy(lens))
+ N, T, _ = x.shape
+ x = x.unsqueeze(1)
+ x = self.features(x)
+ x = x.transpose(1, 2).contiguous().flatten(-2)
+ x, _ = self.gru(x)
+ if input_dict["upsample"]:
+ x = nn.functional.interpolate(
+ x.transpose(1, 2),
+ T,
+ mode='linear',
+ align_corners=False).transpose(1, 2)
+ else:
+ lens //= 4
+ attn_emb = x
+ fc_emb = embedding_pooling(x, lens, self.pooling)
+ return {
+ "attn_emb": attn_emb,
+ "fc_emb": fc_emb,
+ "attn_emb_len": lens
+ }
+
+
+def conv_conv_block(in_channel, out_channel):
+ return nn.Sequential(
+ nn.Conv2d(in_channel,
+ out_channel,
+ kernel_size=3,
+ bias=False,
+ padding=1),
+ nn.BatchNorm2d(out_channel),
+ nn.ReLU(True),
+ nn.Conv2d(out_channel,
+ out_channel,
+ kernel_size=3,
+ bias=False,
+ padding=1),
+ nn.BatchNorm2d(out_channel),
+ nn.ReLU(True)
+ )
+
+
+class Cdur8Encoder(BaseEncoder):
+
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
+ self.pooling = pooling
+ self.features = nn.Sequential(
+ conv_conv_block(1, 64),
+ MMPool((2, 2)),
+ nn.Dropout(0.2, True),
+ conv_conv_block(64, 128),
+ MMPool((2, 2)),
+ nn.Dropout(0.2, True),
+ conv_conv_block(128, 256),
+ MMPool((1, 2)),
+ nn.Dropout(0.2, True),
+ conv_conv_block(256, 512),
+ MMPool((1, 2)),
+ nn.Dropout(0.2, True),
+ nn.AdaptiveAvgPool2d((None, 1)),
+ )
+ self.init_bn = nn.BatchNorm2d(spec_dim)
+ self.embedding = nn.Linear(512, 512)
+ self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ x = input_dict["spec"]
+ lens = input_dict["spec_len"]
+ lens = torch.as_tensor(copy.deepcopy(lens))
+ x = x.unsqueeze(1) # B x 1 x T x D
+ x = x.transpose(1, 3)
+ x = self.init_bn(x)
+ x = x.transpose(1, 3)
+ x = self.features(x)
+ x = x.transpose(1, 2).contiguous().flatten(-2)
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.embedding(x))
+ x, _ = self.gru(x)
+ attn_emb = x
+ lens //= 4
+ fc_emb = embedding_pooling(x, lens, self.pooling)
+ return {
+ "attn_emb": attn_emb,
+ "fc_emb": fc_emb,
+ "attn_emb_len": lens
+ }
+
+
+class Cnn10Encoder(BaseEncoder):
+
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
+ self.features = nn.Sequential(
+ conv_conv_block(1, 64),
+ nn.AvgPool2d((2, 2)),
+ nn.Dropout(0.2, True),
+ conv_conv_block(64, 128),
+ nn.AvgPool2d((2, 2)),
+ nn.Dropout(0.2, True),
+ conv_conv_block(128, 256),
+ nn.AvgPool2d((2, 2)),
+ nn.Dropout(0.2, True),
+ conv_conv_block(256, 512),
+ nn.AvgPool2d((2, 2)),
+ nn.Dropout(0.2, True),
+ nn.AdaptiveAvgPool2d((None, 1)),
+ )
+ self.init_bn = nn.BatchNorm2d(spec_dim)
+ self.embedding = nn.Linear(512, 512)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ x = input_dict["spec"]
+ lens = input_dict["spec_len"]
+ lens = torch.as_tensor(copy.deepcopy(lens))
+ x = x.unsqueeze(1) # [N, 1, T, D]
+ x = x.transpose(1, 3)
+ x = self.init_bn(x)
+ x = x.transpose(1, 3)
+ x = self.features(x) # [N, 512, T/16, 1]
+ x = x.transpose(1, 2).contiguous().flatten(-2) # [N, T/16, 512]
+ attn_emb = x
+ lens //= 16
+ fc_emb = embedding_pooling(x, lens, "mean+max")
+ fc_emb = F.dropout(fc_emb, p=0.5, training=self.training)
+ fc_emb = self.embedding(fc_emb)
+ fc_emb = F.relu_(fc_emb)
+ return {
+ "attn_emb": attn_emb,
+ "fc_emb": fc_emb,
+ "attn_emb_len": lens
+ }
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels):
+
+ super(ConvBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3), stride=(1, 1),
+ padding=(1, 1), bias=False)
+
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3), stride=(1, 1),
+ padding=(1, 1), bias=False)
+
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_layer(self.conv1)
+ init_layer(self.conv2)
+ init_bn(self.bn1)
+ init_bn(self.bn2)
+
+
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
+
+ x = input
+ x = F.relu_(self.bn1(self.conv1(x)))
+ x = F.relu_(self.bn2(self.conv2(x)))
+ if pool_type == 'max':
+ x = F.max_pool2d(x, kernel_size=pool_size)
+ elif pool_type == 'avg':
+ x = F.avg_pool2d(x, kernel_size=pool_size)
+ elif pool_type == 'avg+max':
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
+ x = x1 + x2
+ else:
+ raise Exception('Incorrect argument!')
+
+ return x
+
+
+class Cnn14Encoder(nn.Module):
+ def __init__(self, sample_rate=32000):
+ super().__init__()
+ sr_to_fmax = {
+ 32000: 14000,
+ 16000: 8000
+ }
+ # Logmel spectrogram extractor
+ self.melspec_extractor = transforms.MelSpectrogram(
+ sample_rate=sample_rate,
+ n_fft=32 * sample_rate // 1000,
+ win_length=32 * sample_rate // 1000,
+ hop_length=10 * sample_rate // 1000,
+ f_min=50,
+ f_max=sr_to_fmax[sample_rate],
+ n_mels=64,
+ norm="slaney",
+ mel_scale="slaney"
+ )
+ self.hop_length = 10 * sample_rate // 1000
+ self.db_transform = transforms.AmplitudeToDB()
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64,
+ time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2)
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+ self.downsample_ratio = 32
+
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+
+ def load_pretrained(self, pretrained):
+ checkpoint = torch.load(pretrained, map_location="cpu")
+
+ if "model" in checkpoint:
+ state_keys = checkpoint["model"].keys()
+ backbone = False
+ for key in state_keys:
+ if key.startswith("backbone."):
+ backbone = True
+ break
+
+ if backbone: # COLA
+ state_dict = {}
+ for key, value in checkpoint["model"].items():
+ if key.startswith("backbone."):
+ model_key = key.replace("backbone.", "")
+ state_dict[model_key] = value
+ else: # PANNs
+ state_dict = checkpoint["model"]
+ elif "state_dict" in checkpoint: # CLAP
+ state_dict = checkpoint["state_dict"]
+ state_dict_keys = list(filter(
+ lambda x: "audio_encoder" in x, state_dict.keys()))
+ state_dict = {
+ key.replace('audio_encoder.', ''): state_dict[key]
+ for key in state_dict_keys
+ }
+ else:
+ raise Exception("Unkown checkpoint format")
+
+ model_dict = self.state_dict()
+ pretrained_dict = {
+ k: v for k, v in state_dict.items() if (k in model_dict) and (
+ model_dict[k].shape == v.shape)
+ }
+ model_dict.update(pretrained_dict)
+ self.load_state_dict(model_dict, strict=True)
+
+ def forward(self, input_dict):
+ """
+ Input: (batch_size, n_samples)"""
+ waveform = input_dict["wav"]
+ wave_length = input_dict["wav_len"]
+ specaug = input_dict["specaug"]
+ x = self.melspec_extractor(waveform)
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
+ x = x.transpose(1, 2)
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
+
+ # SpecAugment
+ if self.training and specaug:
+ x = self.spec_augmenter(x)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+ attn_emb = x.transpose(1, 2)
+
+ wave_length = torch.as_tensor(wave_length)
+ feat_length = torch.div(wave_length, self.hop_length,
+ rounding_mode="floor") + 1
+ feat_length = torch.div(feat_length, self.downsample_ratio,
+ rounding_mode="floor")
+ x_max = max_with_lens(attn_emb, feat_length)
+ x_mean = mean_with_lens(attn_emb, feat_length)
+ x = x_max + x_mean
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
+
+ output_dict = {
+ 'fc_emb': fc_emb,
+ 'attn_emb': attn_emb,
+ 'attn_emb_len': feat_length
+ }
+
+ return output_dict
+
+
+class RnnEncoder(BaseEncoder):
+
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim,
+ pooling="mean", **kwargs):
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
+ self.pooling = pooling
+ self.hidden_size = kwargs.get('hidden_size', 512)
+ self.bidirectional = kwargs.get('bidirectional', False)
+ self.num_layers = kwargs.get('num_layers', 1)
+ self.dropout = kwargs.get('dropout', 0.2)
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
+ self.in_bn = kwargs.get('in_bn', False)
+ self.embed_dim = self.hidden_size * (self.bidirectional + 1)
+ self.network = getattr(nn, self.rnn_type)(
+ attn_feat_dim,
+ self.hidden_size,
+ num_layers=self.num_layers,
+ bidirectional=self.bidirectional,
+ dropout=self.dropout,
+ batch_first=True)
+ if self.in_bn:
+ self.bn = nn.BatchNorm1d(self.embed_dim)
+ self.apply(init)
+
+ def forward(self, input_dict):
+ x = input_dict["attn"]
+ lens = input_dict["attn_len"]
+ lens = torch.as_tensor(lens)
+ # x: [N, T, E]
+ if self.in_bn:
+ x = pack_wrapper(self.bn, x, lens)
+ out = pack_wrapper(self.network, x, lens)
+ # out: [N, T, hidden]
+ attn_emb = out
+ fc_emb = embedding_pooling(out, lens, self.pooling)
+ return {
+ "attn_emb": attn_emb,
+ "fc_emb": fc_emb,
+ "attn_emb_len": lens
+ }
+
+
+class Cnn14RnnEncoder(nn.Module):
+ def __init__(self, sample_rate=32000, pretrained=None,
+ freeze_cnn=False, freeze_cnn_bn=False,
+ pooling="mean", **kwargs):
+ super().__init__()
+ self.cnn = Cnn14Encoder(sample_rate)
+ self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs)
+ if pretrained is not None:
+ self.cnn.load_pretrained(pretrained)
+ if freeze_cnn:
+ assert pretrained is not None, "cnn is not pretrained but frozen"
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.freeze_cnn_bn = freeze_cnn_bn
+
+ def train(self, mode):
+ super().train(mode=mode)
+ if self.freeze_cnn_bn:
+ def bn_eval(module):
+ class_name = module.__class__.__name__
+ if class_name.find("BatchNorm") != -1:
+ module.eval()
+ self.cnn.apply(bn_eval)
+ return self
+
+ def forward(self, input_dict):
+ output_dict = self.cnn(input_dict)
+ output_dict["attn"] = output_dict["attn_emb"]
+ output_dict["attn_len"] = output_dict["attn_emb_len"]
+ del output_dict["attn_emb"], output_dict["attn_emb_len"]
+ output_dict = self.rnn(output_dict)
+ return output_dict
+
+
+class TransformerEncoder(BaseEncoder):
+
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs):
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
+ self.d_model = d_model
+ dropout = kwargs.get("dropout", 0.2)
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
+ self.nlayers = kwargs.get("nlayers", 2)
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
+
+ self.attn_proj = nn.Sequential(
+ nn.Linear(attn_feat_dim, self.d_model),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.LayerNorm(self.d_model)
+ )
+ layer = nn.TransformerEncoderLayer(d_model=self.d_model,
+ nhead=self.nhead,
+ dim_feedforward=self.dim_feedforward,
+ dropout=dropout)
+ self.model = nn.TransformerEncoder(layer, self.nlayers)
+ self.cls_token = nn.Parameter(torch.zeros(d_model))
+ self.init_params()
+
+ def init_params(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, input_dict):
+ attn_feat = input_dict["attn"]
+ attn_feat_len = input_dict["attn_len"]
+ attn_feat_len = torch.as_tensor(attn_feat_len)
+
+ attn_feat = self.attn_proj(attn_feat) # [bs, T, d_model]
+
+ cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat(
+ attn_feat.size(0), 1, 1)
+ attn_feat = torch.cat((cls_emb, attn_feat), dim=1)
+ attn_feat = attn_feat.transpose(0, 1)
+
+ attn_feat_len += 1
+ src_key_padding_mask = ~generate_length_mask(
+ attn_feat_len, attn_feat.size(0)).to(attn_feat.device)
+ output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask)
+
+ attn_emb = output.transpose(0, 1)
+ fc_emb = attn_emb[:, 0]
+ return {
+ "attn_emb": attn_emb,
+ "fc_emb": fc_emb,
+ "attn_emb_len": attn_feat_len
+ }
+
+
+class Cnn14TransformerEncoder(nn.Module):
+ def __init__(self, sample_rate=32000, pretrained=None,
+ freeze_cnn=False, freeze_cnn_bn=False,
+ d_model="mean", **kwargs):
+ super().__init__()
+ self.cnn = Cnn14Encoder(sample_rate)
+ self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs)
+ if pretrained is not None:
+ self.cnn.load_pretrained(pretrained)
+ if freeze_cnn:
+ assert pretrained is not None, "cnn is not pretrained but frozen"
+ for param in self.cnn.parameters():
+ param.requires_grad = False
+ self.freeze_cnn_bn = freeze_cnn_bn
+
+ def train(self, mode):
+ super().train(mode=mode)
+ if self.freeze_cnn_bn:
+ def bn_eval(module):
+ class_name = module.__class__.__name__
+ if class_name.find("BatchNorm") != -1:
+ module.eval()
+ self.cnn.apply(bn_eval)
+ return self
+
+ def forward(self, input_dict):
+ output_dict = self.cnn(input_dict)
+ output_dict["attn"] = output_dict["attn_emb"]
+ output_dict["attn_len"] = output_dict["attn_emb_len"]
+ del output_dict["attn_emb"], output_dict["attn_emb_len"]
+ output_dict = self.trm(output_dict)
+ return output_dict
+
+
+
+
+
diff --git a/audio_to_text/captioning/models/transformer_model.py b/audio_to_text/captioning/models/transformer_model.py
new file mode 100644
index 0000000..76c97f1
--- /dev/null
+++ b/audio_to_text/captioning/models/transformer_model.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+import random
+import torch
+import torch.nn as nn
+
+from .base_model import CaptionModel
+from .utils import repeat_tensor
+import audio_to_text.captioning.models.decoder
+
+
+class TransformerModel(CaptionModel):
+
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
+ if not hasattr(self, "compatible_decoders"):
+ self.compatible_decoders = (
+ audio_to_text.captioning.models.decoder.TransformerDecoder,
+ )
+ super().__init__(encoder, decoder, **kwargs)
+
+ def seq_forward(self, input_dict):
+ cap = input_dict["cap"]
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
+ cap_padding_mask = cap_padding_mask[:, :-1]
+ output = self.decoder(
+ {
+ "word": cap[:, :-1],
+ "attn_emb": input_dict["attn_emb"],
+ "attn_emb_len": input_dict["attn_emb_len"],
+ "cap_padding_mask": cap_padding_mask
+ }
+ )
+ return output
+
+ def prepare_decoder_input(self, input_dict, output):
+ decoder_input = {
+ "attn_emb": input_dict["attn_emb"],
+ "attn_emb_len": input_dict["attn_emb_len"]
+ }
+ t = input_dict["t"]
+
+ ###############
+ # determine input word
+ ################
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
+ word = input_dict["cap"][:, :t+1]
+ else:
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
+ if t == 0:
+ word = start_word
+ else:
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
+ # word: [N, T]
+ decoder_input["word"] = word
+
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
+ decoder_input["cap_padding_mask"] = cap_padding_mask
+ return decoder_input
+
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
+ decoder_input = {}
+ t = input_dict["t"]
+ i = input_dict["sample_idx"]
+ beam_size = input_dict["beam_size"]
+ ###############
+ # prepare attn embeds
+ ################
+ if t == 0:
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
+ output_i["attn_emb"] = attn_emb
+ output_i["attn_emb_len"] = attn_emb_len
+ decoder_input["attn_emb"] = output_i["attn_emb"]
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
+ ###############
+ # determine input word
+ ################
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
+ if t == 0:
+ word = start_word
+ else:
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
+ decoder_input["word"] = word
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
+ decoder_input["cap_padding_mask"] = cap_padding_mask
+
+ return decoder_input
+
+
+class M2TransformerModel(CaptionModel):
+
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
+ if not hasattr(self, "compatible_decoders"):
+ self.compatible_decoders = (
+ captioning.models.decoder.M2TransformerDecoder,
+ )
+ super().__init__(encoder, decoder, **kwargs)
+ self.check_encoder_compatibility()
+
+ def check_encoder_compatibility(self):
+ assert isinstance(self.encoder, captioning.models.encoder.M2TransformerEncoder), \
+ f"only M2TransformerModel is compatible with {self.__class__.__name__}"
+
+
+ def seq_forward(self, input_dict):
+ cap = input_dict["cap"]
+ output = self.decoder(
+ {
+ "word": cap[:, :-1],
+ "attn_emb": input_dict["attn_emb"],
+ "attn_emb_mask": input_dict["attn_emb_mask"],
+ }
+ )
+ return output
+
+ def prepare_decoder_input(self, input_dict, output):
+ decoder_input = {
+ "attn_emb": input_dict["attn_emb"],
+ "attn_emb_mask": input_dict["attn_emb_mask"]
+ }
+ t = input_dict["t"]
+
+ ###############
+ # determine input word
+ ################
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
+ word = input_dict["cap"][:, :t+1]
+ else:
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
+ if t == 0:
+ word = start_word
+ else:
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
+ # word: [N, T]
+ decoder_input["word"] = word
+
+ return decoder_input
+
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
+ decoder_input = {}
+ t = input_dict["t"]
+ i = input_dict["sample_idx"]
+ beam_size = input_dict["beam_size"]
+ ###############
+ # prepare attn embeds
+ ################
+ if t == 0:
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
+ attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
+ output_i["attn_emb"] = attn_emb
+ output_i["attn_emb_mask"] = attn_emb_mask
+ decoder_input["attn_emb"] = output_i["attn_emb"]
+ decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
+ ###############
+ # determine input word
+ ################
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
+ if t == 0:
+ word = start_word
+ else:
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
+ decoder_input["word"] = word
+
+ return decoder_input
+
+
+class EventEncoder(nn.Module):
+ """
+ Encode the Label information in AudioCaps and AudioSet
+ """
+ def __init__(self, emb_dim, vocab_size=527):
+ super(EventEncoder, self).__init__()
+ self.label_embedding = nn.Parameter(
+ torch.randn((vocab_size, emb_dim)), requires_grad=True)
+
+ def forward(self, word_idxs):
+ indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
+ embeddings = indices @ self.label_embedding
+ return embeddings
+
+
+class EventCondTransformerModel(TransformerModel):
+
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
+ if not hasattr(self, "compatible_decoders"):
+ self.compatible_decoders = (
+ captioning.models.decoder.EventTransformerDecoder,
+ )
+ super().__init__(encoder, decoder, **kwargs)
+ self.label_encoder = EventEncoder(decoder.emb_dim, 527)
+ self.train_forward_keys += ["events"]
+ self.inference_forward_keys += ["events"]
+
+ # def seq_forward(self, input_dict):
+ # cap = input_dict["cap"]
+ # cap_padding_mask = (cap == self.pad_idx).to(cap.device)
+ # cap_padding_mask = cap_padding_mask[:, :-1]
+ # output = self.decoder(
+ # {
+ # "word": cap[:, :-1],
+ # "attn_emb": input_dict["attn_emb"],
+ # "attn_emb_len": input_dict["attn_emb_len"],
+ # "cap_padding_mask": cap_padding_mask
+ # }
+ # )
+ # return output
+
+ def prepare_decoder_input(self, input_dict, output):
+ decoder_input = super().prepare_decoder_input(input_dict, output)
+ decoder_input["events"] = self.label_encoder(input_dict["events"])
+ return decoder_input
+
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
+ t = input_dict["t"]
+ i = input_dict["sample_idx"]
+ beam_size = input_dict["beam_size"]
+ if t == 0:
+ output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
+ decoder_input["events"] = output_i["events"]
+ return decoder_input
+
+
+class KeywordCondTransformerModel(TransformerModel):
+
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
+ if not hasattr(self, "compatible_decoders"):
+ self.compatible_decoders = (
+ captioning.models.decoder.KeywordProbTransformerDecoder,
+ )
+ super().__init__(encoder, decoder, **kwargs)
+ self.train_forward_keys += ["keyword"]
+ self.inference_forward_keys += ["keyword"]
+
+ def seq_forward(self, input_dict):
+ cap = input_dict["cap"]
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
+ cap_padding_mask = cap_padding_mask[:, :-1]
+ keyword = input_dict["keyword"]
+ output = self.decoder(
+ {
+ "word": cap[:, :-1],
+ "attn_emb": input_dict["attn_emb"],
+ "attn_emb_len": input_dict["attn_emb_len"],
+ "keyword": keyword,
+ "cap_padding_mask": cap_padding_mask
+ }
+ )
+ return output
+
+ def prepare_decoder_input(self, input_dict, output):
+ decoder_input = super().prepare_decoder_input(input_dict, output)
+ decoder_input["keyword"] = input_dict["keyword"]
+ return decoder_input
+
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
+ t = input_dict["t"]
+ i = input_dict["sample_idx"]
+ beam_size = input_dict["beam_size"]
+ if t == 0:
+ output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
+ beam_size)
+ decoder_input["keyword"] = output_i["keyword"]
+ return decoder_input
+
diff --git a/audio_to_text/captioning/models/utils.py b/audio_to_text/captioning/models/utils.py
new file mode 100644
index 0000000..3623cf4
--- /dev/null
+++ b/audio_to_text/captioning/models/utils.py
@@ -0,0 +1,132 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
+
+
+def sort_pack_padded_sequence(input, lengths):
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
+ tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
+ inv_ix = indices.clone()
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
+ return tmp, inv_ix
+
+def pad_unsort_packed_sequence(input, inv_ix):
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
+ tmp = tmp[inv_ix]
+ return tmp
+
+def pack_wrapper(module, attn_feats, attn_feat_lens):
+ packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
+ if isinstance(module, torch.nn.RNNBase):
+ return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
+ else:
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
+
+def generate_length_mask(lens, max_length=None):
+ lens = torch.as_tensor(lens)
+ N = lens.size(0)
+ if max_length is None:
+ max_length = max(lens)
+ idxs = torch.arange(max_length).repeat(N).view(N, max_length)
+ idxs = idxs.to(lens.device)
+ mask = (idxs < lens.view(-1, 1))
+ return mask
+
+def mean_with_lens(features, lens):
+ """
+ features: [N, T, ...] (assume the second dimension represents length)
+ lens: [N,]
+ """
+ lens = torch.as_tensor(lens)
+ if max(lens) != features.size(1):
+ max_length = features.size(1)
+ mask = generate_length_mask(lens, max_length)
+ else:
+ mask = generate_length_mask(lens)
+ mask = mask.to(features.device) # [N, T]
+
+ while mask.ndim < features.ndim:
+ mask = mask.unsqueeze(-1)
+ feature_mean = features * mask
+ feature_mean = feature_mean.sum(1)
+ while lens.ndim < feature_mean.ndim:
+ lens = lens.unsqueeze(1)
+ feature_mean = feature_mean / lens.to(features.device)
+ # feature_mean = features * mask.unsqueeze(-1)
+ # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
+ return feature_mean
+
+def max_with_lens(features, lens):
+ """
+ features: [N, T, ...] (assume the second dimension represents length)
+ lens: [N,]
+ """
+ lens = torch.as_tensor(lens)
+ mask = generate_length_mask(lens).to(features.device) # [N, T]
+
+ feature_max = features.clone()
+ feature_max[~mask] = float("-inf")
+ feature_max, _ = feature_max.max(1)
+ return feature_max
+
+def repeat_tensor(x, n):
+ return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
+
+def init(m, method="kaiming"):
+ if isinstance(m, (nn.Conv2d, nn.Conv1d)):
+ if method == "kaiming":
+ nn.init.kaiming_uniform_(m.weight)
+ elif method == "xavier":
+ nn.init.xavier_uniform_(m.weight)
+ else:
+ raise Exception(f"initialization method {method} not supported")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ if method == "kaiming":
+ nn.init.kaiming_uniform_(m.weight)
+ elif method == "xavier":
+ nn.init.xavier_uniform_(m.weight)
+ else:
+ raise Exception(f"initialization method {method} not supported")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Embedding):
+ if method == "kaiming":
+ nn.init.kaiming_uniform_(m.weight)
+ elif method == "xavier":
+ nn.init.xavier_uniform_(m.weight)
+ else:
+ raise Exception(f"initialization method {method} not supported")
+
+
+
+
+class PositionalEncoding(nn.Module):
+
+ def __init__(self, d_model, dropout=0.1, max_len=100):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
+ (-math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+ # self.register_buffer("pe", pe)
+ self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
+
+ def forward(self, x):
+ # x: [T, N, E]
+ x = x + self.pe[:x.size(0), :]
+ return self.dropout(x)
diff --git a/audio_to_text/captioning/utils/README.md b/audio_to_text/captioning/utils/README.md
new file mode 100644
index 0000000..c6fd17d
--- /dev/null
+++ b/audio_to_text/captioning/utils/README.md
@@ -0,0 +1,19 @@
+# Utils
+
+Scripts in this directory are used as utility functions.
+
+## BERT Pretrained Embeddings
+
+You can load pretrained word embeddings in Google [BERT](https://github.com/google-research/bert#pre-trained-models) instead of training word embeddings from scratch. The scripts in `utils/bert` need a BERT server in the background. We use BERT server from [bert-as-service](https://github.com/hanxiao/bert-as-service).
+
+To use bert-as-service, you need to first install the repository. It is recommended that you create a new environment with Tensorflow 1.3 to run BERT server since it is incompatible with Tensorflow 2.x.
+
+After successful installation of [bert-as-service](https://github.com/hanxiao/bert-as-service), downloading and running the BERT server needs to execute:
+
+```bash
+bash scripts/prepare_bert_server.sh zh
+```
+
+By default, server based on BERT base Chinese model is running in the background. You can change to other models by changing corresponding model name and path in `scripts/prepare_bert_server.sh`.
+
+To extract BERT word embeddings, you need to execute `utils/bert/create_word_embedding.py`.
diff --git a/audio_to_text/captioning/utils/__init__.py b/audio_to_text/captioning/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/audio_to_text/captioning/utils/bert/create_sent_embedding.py b/audio_to_text/captioning/utils/bert/create_sent_embedding.py
new file mode 100644
index 0000000..b517a32
--- /dev/null
+++ b/audio_to_text/captioning/utils/bert/create_sent_embedding.py
@@ -0,0 +1,89 @@
+import pickle
+import fire
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+
+class EmbeddingExtractor(object):
+
+ def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
+ from sentence_transformers import SentenceTransformer
+ lang2model = {
+ "zh": "distiluse-base-multilingual-cased",
+ "en": "bert-base-nli-mean-tokens"
+ }
+ lang = "zh" if zh else "en"
+ model = SentenceTransformer(lang2model[lang])
+
+ self.extract(caption_file, model, output, dev)
+
+ def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
+ from bert_serving.client import BertClient
+ client = BertClient(ip)
+
+ self.extract(caption_file, client, output, dev)
+
+ def extract(self, caption_file: str, model, output, dev: bool):
+ caption_df = pd.read_json(caption_file, dtype={"key": str})
+ embeddings = {}
+
+ if dev:
+ with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
+ for idx, row in caption_df.iterrows():
+ caption = row["caption"]
+ key = row["key"]
+ cap_idx = row["caption_index"]
+ embedding = model.encode([caption])
+ embedding = np.array(embedding).reshape(-1)
+ embeddings[f"{key}_{cap_idx}"] = embedding
+ pbar.update()
+
+ else:
+ dump = {}
+
+ with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
+ for idx, row in caption_df.iterrows():
+ key = row["key"]
+ caption = row["caption"]
+ value = np.array(model.encode([caption])).reshape(-1)
+
+ if key not in embeddings.keys():
+ embeddings[key] = [value]
+ else:
+ embeddings[key].append(value)
+
+ pbar.update()
+
+ for key in embeddings:
+ dump[key] = np.stack(embeddings[key])
+
+ embeddings = dump
+
+ with open(output, "wb") as f:
+ pickle.dump(embeddings, f)
+
+ def extract_sbert(self,
+ input_json: str,
+ output: str):
+ from sentence_transformers import SentenceTransformer
+ import json
+ import torch
+ from h5py import File
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
+ model = model.to(device)
+ model.eval()
+
+ data = json.load(open(input_json))["audios"]
+ with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
+ for sample in data:
+ audio_id = sample["audio_id"]
+ for cap in sample["captions"]:
+ cap_id = cap["cap_id"]
+ store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
+ pbar.update()
+
+
+if __name__ == "__main__":
+ fire.Fire(EmbeddingExtractor)
diff --git a/audio_to_text/captioning/utils/bert/create_word_embedding.py b/audio_to_text/captioning/utils/bert/create_word_embedding.py
new file mode 100644
index 0000000..43c980e
--- /dev/null
+++ b/audio_to_text/captioning/utils/bert/create_word_embedding.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+
+import sys
+import os
+
+from bert_serving.client import BertClient
+import numpy as np
+from tqdm import tqdm
+import fire
+import torch
+
+sys.path.append(os.getcwd())
+from utils.build_vocab import Vocabulary
+
+def main(vocab_file: str, output: str, server_hostname: str):
+ client = BertClient(ip=server_hostname)
+ vocabulary = torch.load(vocab_file)
+ vocab_size = len(vocabulary)
+
+ fake_embedding = client.encode(["test"]).reshape(-1)
+ embed_size = fake_embedding.shape[0]
+
+ print("Encoding words into embeddings with size: ", embed_size)
+
+ embeddings = np.empty((vocab_size, embed_size))
+ for i in tqdm(range(len(embeddings)), ascii=True):
+ embeddings[i] = client.encode([vocabulary.idx2word[i]])
+ np.save(output, embeddings)
+
+
+if __name__ == '__main__':
+ fire.Fire(main)
+
+
diff --git a/audio_to_text/captioning/utils/build_vocab.py b/audio_to_text/captioning/utils/build_vocab.py
new file mode 100644
index 0000000..e9fab23
--- /dev/null
+++ b/audio_to_text/captioning/utils/build_vocab.py
@@ -0,0 +1,153 @@
+import json
+from tqdm import tqdm
+import logging
+import pickle
+from collections import Counter
+import re
+import fire
+
+
+class Vocabulary(object):
+ """Simple vocabulary wrapper."""
+ def __init__(self):
+ self.word2idx = {}
+ self.idx2word = {}
+ self.idx = 0
+
+ def add_word(self, word):
+ if not word in self.word2idx:
+ self.word2idx[word] = self.idx
+ self.idx2word[self.idx] = word
+ self.idx += 1
+
+ def __call__(self, word):
+ if not word in self.word2idx:
+ return self.word2idx[""]
+ return self.word2idx[word]
+
+ def __getitem__(self, word_id):
+ return self.idx2word[word_id]
+
+ def __len__(self):
+ return len(self.word2idx)
+
+
+def build_vocab(input_json: str,
+ threshold: int,
+ keep_punctuation: bool,
+ host_address: str,
+ character_level: bool = False,
+ zh: bool = True ):
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
+
+ Args:
+ input_json(string): Preprossessed json file. Structure like this:
+ {
+ 'audios': [
+ {
+ 'audio_id': 'xxx',
+ 'captions': [
+ {
+ 'caption': 'xxx',
+ 'cap_id': 'xxx'
+ }
+ ]
+ },
+ ...
+ ]
+ }
+ threshold (int): Threshold to drop all words with counts < threshold
+ keep_punctuation (bool): Includes or excludes punctuation.
+
+ Returns:
+ vocab (Vocab): Object with the processed vocabulary
+"""
+ data = json.load(open(input_json, "r"))["audios"]
+ counter = Counter()
+ pretokenized = "tokens" in data[0]["captions"][0]
+
+ if zh:
+ from nltk.parse.corenlp import CoreNLPParser
+ from zhon.hanzi import punctuation
+ if not pretokenized:
+ parser = CoreNLPParser(host_address)
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ if pretokenized:
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
+ else:
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
+ # Remove all punctuations
+ if not keep_punctuation:
+ caption = re.sub("[{}]".format(punctuation), "", caption)
+ if character_level:
+ tokens = list(caption)
+ else:
+ tokens = list(parser.tokenize(caption))
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
+ counter.update(tokens)
+ else:
+ if pretokenized:
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
+ counter.update(tokens)
+ else:
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+ captions = {}
+ for audio_idx in range(len(data)):
+ audio_id = data[audio_idx]["audio_id"]
+ captions[audio_id] = []
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
+ captions[audio_id].append({
+ "audio_id": audio_id,
+ "id": cap_idx,
+ "caption": caption
+ })
+ tokenizer = PTBTokenizer()
+ captions = tokenizer.tokenize(captions)
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ audio_id = data[audio_idx]["audio_id"]
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ tokens = captions[audio_id][cap_idx]
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
+ counter.update(tokens.split(" "))
+
+ if not pretokenized:
+ json.dump({ "audios": data }, open(input_json, "w"), indent=4, ensure_ascii=not zh)
+ words = [word for word, cnt in counter.items() if cnt >= threshold]
+
+ # Create a vocab wrapper and add some special tokens.
+ vocab = Vocabulary()
+ vocab.add_word("")
+ vocab.add_word("")
+ vocab.add_word("")
+ vocab.add_word("")
+
+ # Add the words to the vocabulary.
+ for word in words:
+ vocab.add_word(word)
+ return vocab
+
+
+def process(input_json: str,
+ output_file: str,
+ threshold: int = 1,
+ keep_punctuation: bool = False,
+ character_level: bool = False,
+ host_address: str = "http://localhost:9000",
+ zh: bool = False):
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ logging.basicConfig(level=logging.INFO, format=logfmt)
+ logging.info("Build Vocab")
+ vocabulary = build_vocab(
+ input_json=input_json, threshold=threshold, keep_punctuation=keep_punctuation,
+ host_address=host_address, character_level=character_level, zh=zh)
+ pickle.dump(vocabulary, open(output_file, "wb"))
+ logging.info("Total vocabulary size: {}".format(len(vocabulary)))
+ logging.info("Saved vocab to '{}'".format(output_file))
+
+
+if __name__ == '__main__':
+ fire.Fire(process)
diff --git a/audio_to_text/captioning/utils/build_vocab_ltp.py b/audio_to_text/captioning/utils/build_vocab_ltp.py
new file mode 100644
index 0000000..aae0c71
--- /dev/null
+++ b/audio_to_text/captioning/utils/build_vocab_ltp.py
@@ -0,0 +1,150 @@
+import json
+from tqdm import tqdm
+import logging
+import pickle
+from collections import Counter
+import re
+import fire
+
+class Vocabulary(object):
+ """Simple vocabulary wrapper."""
+ def __init__(self):
+ self.word2idx = {}
+ self.idx2word = {}
+ self.idx = 0
+
+ def add_word(self, word):
+ if not word in self.word2idx:
+ self.word2idx[word] = self.idx
+ self.idx2word[self.idx] = word
+ self.idx += 1
+
+ def __call__(self, word):
+ if not word in self.word2idx:
+ return self.word2idx[""]
+ return self.word2idx[word]
+
+ def __len__(self):
+ return len(self.word2idx)
+
+def build_vocab(input_json: str,
+ output_json: str,
+ threshold: int,
+ keep_punctuation: bool,
+ character_level: bool = False,
+ zh: bool = True ):
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
+
+ Args:
+ input_json(string): Preprossessed json file. Structure like this:
+ {
+ 'audios': [
+ {
+ 'audio_id': 'xxx',
+ 'captions': [
+ {
+ 'caption': 'xxx',
+ 'cap_id': 'xxx'
+ }
+ ]
+ },
+ ...
+ ]
+ }
+ threshold (int): Threshold to drop all words with counts < threshold
+ keep_punctuation (bool): Includes or excludes punctuation.
+
+ Returns:
+ vocab (Vocab): Object with the processed vocabulary
+"""
+ data = json.load(open(input_json, "r"))["audios"]
+ counter = Counter()
+ pretokenized = "tokens" in data[0]["captions"][0]
+
+ if zh:
+ from ltp import LTP
+ from zhon.hanzi import punctuation
+ if not pretokenized:
+ parser = LTP("base")
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ if pretokenized:
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
+ else:
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
+ if character_level:
+ tokens = list(caption)
+ else:
+ tokens, _ = parser.seg([caption])
+ tokens = tokens[0]
+ # Remove all punctuations
+ if not keep_punctuation:
+ tokens = [token for token in tokens if token not in punctuation]
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
+ counter.update(tokens)
+ else:
+ if pretokenized:
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
+ counter.update(tokens)
+ else:
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+ captions = {}
+ for audio_idx in range(len(data)):
+ audio_id = data[audio_idx]["audio_id"]
+ captions[audio_id] = []
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
+ captions[audio_id].append({
+ "audio_id": audio_id,
+ "id": cap_idx,
+ "caption": caption
+ })
+ tokenizer = PTBTokenizer()
+ captions = tokenizer.tokenize(captions)
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ audio_id = data[audio_idx]["audio_id"]
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ tokens = captions[audio_id][cap_idx]
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
+ counter.update(tokens.split(" "))
+
+ if not pretokenized:
+ if output_json is None:
+ output_json = input_json
+ json.dump({ "audios": data }, open(output_json, "w"), indent=4, ensure_ascii=not zh)
+ words = [word for word, cnt in counter.items() if cnt >= threshold]
+
+ # Create a vocab wrapper and add some special tokens.
+ vocab = Vocabulary()
+ vocab.add_word("")
+ vocab.add_word("")
+ vocab.add_word("")
+ vocab.add_word("")
+
+ # Add the words to the vocabulary.
+ for word in words:
+ vocab.add_word(word)
+ return vocab
+
+def process(input_json: str,
+ output_file: str,
+ output_json: str = None,
+ threshold: int = 1,
+ keep_punctuation: bool = False,
+ character_level: bool = False,
+ zh: bool = True):
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ logging.basicConfig(level=logging.INFO, format=logfmt)
+ logging.info("Build Vocab")
+ vocabulary = build_vocab(
+ input_json=input_json, output_json=output_json, threshold=threshold,
+ keep_punctuation=keep_punctuation, character_level=character_level, zh=zh)
+ pickle.dump(vocabulary, open(output_file, "wb"))
+ logging.info("Total vocabulary size: {}".format(len(vocabulary)))
+ logging.info("Saved vocab to '{}'".format(output_file))
+
+
+if __name__ == '__main__':
+ fire.Fire(process)
diff --git a/audio_to_text/captioning/utils/build_vocab_spacy.py b/audio_to_text/captioning/utils/build_vocab_spacy.py
new file mode 100644
index 0000000..84da679
--- /dev/null
+++ b/audio_to_text/captioning/utils/build_vocab_spacy.py
@@ -0,0 +1,152 @@
+import json
+from tqdm import tqdm
+import logging
+import pickle
+from collections import Counter
+import re
+import fire
+
+class Vocabulary(object):
+ """Simple vocabulary wrapper."""
+ def __init__(self):
+ self.word2idx = {}
+ self.idx2word = {}
+ self.idx = 0
+
+ def add_word(self, word):
+ if not word in self.word2idx:
+ self.word2idx[word] = self.idx
+ self.idx2word[self.idx] = word
+ self.idx += 1
+
+ def __call__(self, word):
+ if not word in self.word2idx:
+ return self.word2idx[""]
+ return self.word2idx[word]
+
+ def __len__(self):
+ return len(self.word2idx)
+
+
+def build_vocab(input_json: str,
+ output_json: str,
+ threshold: int,
+ keep_punctuation: bool,
+ host_address: str,
+ character_level: bool = False,
+ retokenize: bool = True,
+ zh: bool = True ):
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
+
+ Args:
+ input_json(string): Preprossessed json file. Structure like this:
+ {
+ 'audios': [
+ {
+ 'audio_id': 'xxx',
+ 'captions': [
+ {
+ 'caption': 'xxx',
+ 'cap_id': 'xxx'
+ }
+ ]
+ },
+ ...
+ ]
+ }
+ threshold (int): Threshold to drop all words with counts < threshold
+ keep_punctuation (bool): Includes or excludes punctuation.
+
+ Returns:
+ vocab (Vocab): Object with the processed vocabulary
+"""
+ data = json.load(open(input_json, "r"))["audios"]
+ counter = Counter()
+ if retokenize:
+ pretokenized = False
+ else:
+ pretokenized = "tokens" in data[0]["captions"][0]
+
+ if zh:
+ from nltk.parse.corenlp import CoreNLPParser
+ from zhon.hanzi import punctuation
+ if not pretokenized:
+ parser = CoreNLPParser(host_address)
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ if pretokenized:
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
+ else:
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
+ # Remove all punctuations
+ if not keep_punctuation:
+ caption = re.sub("[{}]".format(punctuation), "", caption)
+ if character_level:
+ tokens = list(caption)
+ else:
+ tokens = list(parser.tokenize(caption))
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
+ counter.update(tokens)
+ else:
+ if pretokenized:
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
+ counter.update(tokens)
+ else:
+ import spacy
+ tokenizer = spacy.load("en_core_web_sm", disable=["parser", "ner"])
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ captions = data[audio_idx]["captions"]
+ for cap_idx in range(len(captions)):
+ caption = captions[cap_idx]["caption"]
+ doc = tokenizer(caption)
+ tokens = " ".join([str(token).lower() for token in doc])
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
+ counter.update(tokens.split(" "))
+
+ if not pretokenized:
+ if output_json is None:
+ json.dump({ "audios": data }, open(input_json, "w"),
+ indent=4, ensure_ascii=not zh)
+ else:
+ json.dump({ "audios": data }, open(output_json, "w"),
+ indent=4, ensure_ascii=not zh)
+
+ words = [word for word, cnt in counter.items() if cnt >= threshold]
+
+ # Create a vocab wrapper and add some special tokens.
+ vocab = Vocabulary()
+ vocab.add_word("")
+ vocab.add_word("")
+ vocab.add_word("")
+ vocab.add_word("")
+
+ # Add the words to the vocabulary.
+ for word in words:
+ vocab.add_word(word)
+ return vocab
+
+def process(input_json: str,
+ output_file: str,
+ output_json: str = None,
+ threshold: int = 1,
+ keep_punctuation: bool = False,
+ character_level: bool = False,
+ retokenize: bool = False,
+ host_address: str = "http://localhost:9000",
+ zh: bool = True):
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ logging.basicConfig(level=logging.INFO, format=logfmt)
+ logging.info("Build Vocab")
+ vocabulary = build_vocab(
+ input_json=input_json, output_json=output_json, threshold=threshold,
+ keep_punctuation=keep_punctuation, host_address=host_address,
+ character_level=character_level, retokenize=retokenize, zh=zh)
+ pickle.dump(vocabulary, open(output_file, "wb"))
+ logging.info("Total vocabulary size: {}".format(len(vocabulary)))
+ logging.info("Saved vocab to '{}'".format(output_file))
+
+
+if __name__ == '__main__':
+ fire.Fire(process)
diff --git a/audio_to_text/captioning/utils/eval_round_robin.py b/audio_to_text/captioning/utils/eval_round_robin.py
new file mode 100644
index 0000000..28603a5
--- /dev/null
+++ b/audio_to_text/captioning/utils/eval_round_robin.py
@@ -0,0 +1,182 @@
+import copy
+import json
+
+import numpy as np
+import fire
+
+
+def evaluate_annotation(key2refs, scorer):
+ if scorer.method() == "Bleu":
+ scores = np.array([ 0.0 for n in range(4) ])
+ else:
+ scores = 0
+ num_cap_per_audio = len(next(iter(key2refs.values())))
+
+ for i in range(num_cap_per_audio):
+ if i > 0:
+ for key in key2refs:
+ key2refs[key].insert(0, res[key][0])
+ res = { key: [refs.pop(),] for key, refs in key2refs.items() }
+ score, _ = scorer.compute_score(key2refs, res)
+
+ if scorer.method() == "Bleu":
+ scores += np.array(score)
+ else:
+ scores += score
+
+ score = scores / num_cap_per_audio
+ return score
+
+def evaluate_prediction(key2pred, key2refs, scorer):
+ if scorer.method() == "Bleu":
+ scores = np.array([ 0.0 for n in range(4) ])
+ else:
+ scores = 0
+ num_cap_per_audio = len(next(iter(key2refs.values())))
+
+ for i in range(num_cap_per_audio):
+ key2refs_i = {}
+ for key, refs in key2refs.items():
+ key2refs_i[key] = refs[:i] + refs[i+1:]
+ score, _ = scorer.compute_score(key2refs_i, key2pred)
+
+ if scorer.method() == "Bleu":
+ scores += np.array(score)
+ else:
+ scores += score
+
+ score = scores / num_cap_per_audio
+ return score
+
+
+class Evaluator(object):
+
+ def eval_annotation(self, annotation, output):
+ captions = json.load(open(annotation, "r"))["audios"]
+
+ key2refs = {}
+ for audio_idx in range(len(captions)):
+ audio_id = captions[audio_idx]["audio_id"]
+ key2refs[audio_id] = []
+ for caption in captions[audio_idx]["captions"]:
+ key2refs[audio_id].append(caption["caption"])
+
+ from fense.fense import Fense
+ scores = {}
+ scorer = Fense()
+ scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
+
+ refs4eval = {}
+ for key, refs in key2refs.items():
+ refs4eval[key] = []
+ for idx, ref in enumerate(refs):
+ refs4eval[key].append({
+ "audio_id": key,
+ "id": idx,
+ "caption": ref
+ })
+
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+
+ tokenizer = PTBTokenizer()
+ key2refs = tokenizer.tokenize(refs4eval)
+
+
+ from pycocoevalcap.bleu.bleu import Bleu
+ from pycocoevalcap.cider.cider import Cider
+ from pycocoevalcap.rouge.rouge import Rouge
+ from pycocoevalcap.meteor.meteor import Meteor
+ from pycocoevalcap.spice.spice import Spice
+
+
+ scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
+ for scorer in scorers:
+ scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
+
+ spider = 0
+ with open(output, "w") as f:
+ for name, score in scores.items():
+ if name == "Bleu":
+ for n in range(4):
+ f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
+ else:
+ f.write("{}: {:6.3f}\n".format(name, score))
+ if name in ["CIDEr", "SPICE"]:
+ spider += score
+ f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
+
+ def eval_prediction(self, prediction, annotation, output):
+ ref_captions = json.load(open(annotation, "r"))["audios"]
+
+ key2refs = {}
+ for audio_idx in range(len(ref_captions)):
+ audio_id = ref_captions[audio_idx]["audio_id"]
+ key2refs[audio_id] = []
+ for caption in ref_captions[audio_idx]["captions"]:
+ key2refs[audio_id].append(caption["caption"])
+
+ pred_captions = json.load(open(prediction, "r"))["predictions"]
+
+ key2pred = {}
+ for audio_idx in range(len(pred_captions)):
+ item = pred_captions[audio_idx]
+ audio_id = item["filename"]
+ key2pred[audio_id] = [item["tokens"]]
+
+ from fense.fense import Fense
+ scores = {}
+ scorer = Fense()
+ scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
+
+ refs4eval = {}
+ for key, refs in key2refs.items():
+ refs4eval[key] = []
+ for idx, ref in enumerate(refs):
+ refs4eval[key].append({
+ "audio_id": key,
+ "id": idx,
+ "caption": ref
+ })
+
+ preds4eval = {}
+ for key, preds in key2pred.items():
+ preds4eval[key] = []
+ for idx, pred in enumerate(preds):
+ preds4eval[key].append({
+ "audio_id": key,
+ "id": idx,
+ "caption": pred
+ })
+
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+
+ tokenizer = PTBTokenizer()
+ key2refs = tokenizer.tokenize(refs4eval)
+ key2pred = tokenizer.tokenize(preds4eval)
+
+
+ from pycocoevalcap.bleu.bleu import Bleu
+ from pycocoevalcap.cider.cider import Cider
+ from pycocoevalcap.rouge.rouge import Rouge
+ from pycocoevalcap.meteor.meteor import Meteor
+ from pycocoevalcap.spice.spice import Spice
+
+ scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
+ for scorer in scorers:
+ scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
+
+ spider = 0
+ with open(output, "w") as f:
+ for name, score in scores.items():
+ if name == "Bleu":
+ for n in range(4):
+ f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
+ else:
+ f.write("{}: {:6.3f}\n".format(name, score))
+ if name in ["CIDEr", "SPICE"]:
+ spider += score
+ f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
+
+
+if __name__ == "__main__":
+ fire.Fire(Evaluator)
diff --git a/audio_to_text/captioning/utils/fasttext/create_word_embedding.py b/audio_to_text/captioning/utils/fasttext/create_word_embedding.py
new file mode 100644
index 0000000..09da13a
--- /dev/null
+++ b/audio_to_text/captioning/utils/fasttext/create_word_embedding.py
@@ -0,0 +1,50 @@
+# coding=utf-8
+#!/usr/bin/env python3
+
+import numpy as np
+import pandas as pd
+import torch
+from gensim.models import FastText
+from tqdm import tqdm
+import fire
+
+import sys
+import os
+sys.path.append(os.getcwd())
+from utils.build_vocab import Vocabulary
+
+def create_embedding(caption_file: str,
+ vocab_file: str,
+ embed_size: int,
+ output: str,
+ **fasttext_kwargs):
+ caption_df = pd.read_json(caption_file)
+ caption_df["tokens"] = caption_df["tokens"].apply(lambda x: [""] + [token for token in x] + [""])
+
+ sentences = list(caption_df["tokens"].values)
+ vocabulary = torch.load(vocab_file, map_location="cpu")
+
+ epochs = fasttext_kwargs.get("epochs", 10)
+ model = FastText(size=embed_size, min_count=1, **fasttext_kwargs)
+ model.build_vocab(sentences=sentences)
+ model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
+
+ word_embeddings = np.zeros((len(vocabulary), embed_size))
+
+ with tqdm(total=len(vocabulary), ascii=True) as pbar:
+ for word, idx in vocabulary.word2idx.items():
+ if word == "" or word == "":
+ continue
+ word_embeddings[idx] = model.wv[word]
+ pbar.update()
+
+ np.save(output, word_embeddings)
+
+ print("Finish writing fasttext embeddings to " + output)
+
+
+if __name__ == "__main__":
+ fire.Fire(create_embedding)
+
+
+
diff --git a/audio_to_text/captioning/utils/lr_scheduler.py b/audio_to_text/captioning/utils/lr_scheduler.py
new file mode 100644
index 0000000..b46e3f0
--- /dev/null
+++ b/audio_to_text/captioning/utils/lr_scheduler.py
@@ -0,0 +1,128 @@
+import math
+import torch
+
+
+class ExponentialDecayScheduler(torch.optim.lr_scheduler._LRScheduler):
+
+ def __init__(self, optimizer, total_iters, final_lrs,
+ warmup_iters=3000, last_epoch=-1, verbose=False):
+ self.total_iters = total_iters
+ self.final_lrs = final_lrs
+ if not isinstance(self.final_lrs, list) and not isinstance(
+ self.final_lrs, tuple):
+ self.final_lrs = [self.final_lrs] * len(optimizer.param_groups)
+ self.warmup_iters = warmup_iters
+ self.bases = [0.0,] * len(optimizer.param_groups)
+ super().__init__(optimizer, last_epoch, verbose)
+ for i, (base_lr, final_lr) in enumerate(zip(self.base_lrs, self.final_lrs)):
+ base = (final_lr / base_lr) ** (1 / (
+ self.total_iters - self.warmup_iters))
+ self.bases[i] = base
+
+ def _get_closed_form_lr(self):
+ warmup_coeff = 1.0
+ current_iter = self._step_count
+ if current_iter < self.warmup_iters:
+ warmup_coeff = current_iter / self.warmup_iters
+ current_lrs = []
+ # if not self.linear_warmup:
+ # for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs, self.bases):
+ # # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
+ # current_lr = warmup_coeff * base_lr * (base ** (current_iter - self.warmup_iters))
+ # current_lrs.append(current_lr)
+ # else:
+ for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs,
+ self.bases):
+ if current_iter <= self.warmup_iters:
+ current_lr = warmup_coeff * base_lr
+ else:
+ # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
+ current_lr = base_lr * (base ** (current_iter - self.warmup_iters))
+ current_lrs.append(current_lr)
+ return current_lrs
+
+ def get_lr(self):
+ return self._get_closed_form_lr()
+
+
+class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
+
+ def __init__(self, optimizer, model_size=512, factor=1, warmup_iters=3000,
+ last_epoch=-1, verbose=False):
+ self.model_size = model_size
+ self.warmup_iters = warmup_iters
+ # self.factors = [group["lr"] / (self.model_size ** (-0.5) * self.warmup_iters ** (-0.5)) for group in optimizer.param_groups]
+ self.factor = factor
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def _get_closed_form_lr(self):
+ current_iter = self._step_count
+ current_lrs = []
+ for _ in self.base_lrs:
+ current_lr = self.factor * \
+ (self.model_size ** (-0.5) * min(current_iter ** (-0.5),
+ current_iter * self.warmup_iters ** (-1.5)))
+ current_lrs.append(current_lr)
+ return current_lrs
+
+ def get_lr(self):
+ return self._get_closed_form_lr()
+
+
+class CosineWithWarmup(torch.optim.lr_scheduler._LRScheduler):
+
+ def __init__(self, optimizer, total_iters, warmup_iters,
+ num_cycles=0.5, last_epoch=-1, verbose=False):
+ self.total_iters = total_iters
+ self.warmup_iters = warmup_iters
+ self.num_cycles = num_cycles
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def lr_lambda(self, iteration):
+ if iteration < self.warmup_iters:
+ return float(iteration) / float(max(1, self.warmup_iters))
+ progress = float(iteration - self.warmup_iters) / float(max(1,
+ self.total_iters - self.warmup_iters))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(
+ self.num_cycles) * 2.0 * progress)))
+
+ def _get_closed_form_lr(self):
+ current_iter = self._step_count
+ current_lrs = []
+ for base_lr in self.base_lrs:
+ current_lr = base_lr * self.lr_lambda(current_iter)
+ current_lrs.append(current_lr)
+ return current_lrs
+
+ def get_lr(self):
+ return self._get_closed_form_lr()
+
+
+if __name__ == "__main__":
+ model = torch.nn.Linear(10, 5)
+ optimizer = torch.optim.Adam(model.parameters(), 5e-4)
+ epochs = 25
+ iters = 600
+ scheduler = CosineWithWarmup(optimizer, 600 * 25, 600 * 5,)
+ # scheduler = ExponentialDecayScheduler(optimizer, 600 * 25, 5e-7, 600 * 5)
+ criterion = torch.nn.MSELoss()
+ lrs = []
+ for epoch in range(1, epochs + 1):
+ for iteration in range(1, iters + 1):
+ optimizer.zero_grad()
+ x = torch.randn(4, 10)
+ y = torch.randn(4, 5)
+ loss = criterion(model(x), y)
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+ # print(f"lr: {scheduler.get_last_lr()}")
+ # lrs.append(scheduler.get_last_lr())
+ lrs.append(optimizer.param_groups[0]["lr"])
+ import matplotlib.pyplot as plt
+ plt.plot(list(range(1, len(lrs) + 1)), lrs, '-o', markersize=1)
+ # plt.legend(loc="best")
+ plt.xlabel("Iteration")
+ plt.ylabel("LR")
+
+ plt.savefig("lr_curve.png", dpi=100)
diff --git a/audio_to_text/captioning/utils/model_eval_diff.py b/audio_to_text/captioning/utils/model_eval_diff.py
new file mode 100644
index 0000000..2c29ef8
--- /dev/null
+++ b/audio_to_text/captioning/utils/model_eval_diff.py
@@ -0,0 +1,110 @@
+import os
+import sys
+import copy
+import pickle
+
+import numpy as np
+import pandas as pd
+import fire
+
+sys.path.append(os.getcwd())
+
+
+def coco_score(refs, pred, scorer):
+ if scorer.method() == "Bleu":
+ scores = np.array([ 0.0 for n in range(4) ])
+ else:
+ scores = 0
+ num_cap_per_audio = len(refs[list(refs.keys())[0]])
+
+ for i in range(num_cap_per_audio):
+ if i > 0:
+ for key in refs:
+ refs[key].insert(0, res[key][0])
+ res = {key: [refs[key].pop(),] for key in refs}
+ score, _ = scorer.compute_score(refs, pred)
+
+ if scorer.method() == "Bleu":
+ scores += np.array(score)
+ else:
+ scores += score
+
+ score = scores / num_cap_per_audio
+
+ for key in refs:
+ refs[key].insert(0, res[key][0])
+ score_allref, _ = scorer.compute_score(refs, pred)
+ diff = score_allref - score
+ return diff
+
+def embedding_score(refs, pred, scorer):
+
+ num_cap_per_audio = len(refs[list(refs.keys())[0]])
+ scores = 0
+
+ for i in range(num_cap_per_audio):
+ res = {key: [refs[key][i],] for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
+ refs_i = {key: np.concatenate([refs[key][:i], refs[key][i+1:]]) for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
+ score, _ = scorer.compute_score(refs_i, pred)
+
+ scores += score
+
+ score = scores / num_cap_per_audio
+
+ score_allref, _ = scorer.compute_score(refs, pred)
+ diff = score_allref - score
+ return diff
+
+def main(output_file, eval_caption_file, eval_embedding_file, output, zh=False):
+ output_df = pd.read_json(output_file)
+ output_df["key"] = output_df["filename"].apply(lambda x: os.path.splitext(os.path.basename(x))[0])
+ pred = output_df.groupby("key")["tokens"].apply(list).to_dict()
+
+ label_df = pd.read_json(eval_caption_file)
+ if zh:
+ refs = label_df.groupby("key")["tokens"].apply(list).to_dict()
+ else:
+ refs = label_df.groupby("key")["caption"].apply(list).to_dict()
+
+ from pycocoevalcap.bleu.bleu import Bleu
+ from pycocoevalcap.cider.cider import Cider
+ from pycocoevalcap.rouge.rouge import Rouge
+
+ scorer = Bleu(zh=zh)
+ bleu_scores = coco_score(copy.deepcopy(refs), pred, scorer)
+ scorer = Cider(zh=zh)
+ cider_score = coco_score(copy.deepcopy(refs), pred, scorer)
+ scorer = Rouge(zh=zh)
+ rouge_score = coco_score(copy.deepcopy(refs), pred, scorer)
+
+ if not zh:
+ from pycocoevalcap.meteor.meteor import Meteor
+ scorer = Meteor()
+ meteor_score = coco_score(copy.deepcopy(refs), pred, scorer)
+
+ from pycocoevalcap.spice.spice import Spice
+ scorer = Spice()
+ spice_score = coco_score(copy.deepcopy(refs), pred, scorer)
+
+ # from audiocaptioneval.sentbert.sentencebert import SentenceBert
+ # scorer = SentenceBert(zh=zh)
+ # with open(eval_embedding_file, "rb") as f:
+ # ref_embeddings = pickle.load(f)
+
+ # sent_bert = embedding_score(ref_embeddings, pred, scorer)
+
+ with open(output, "w") as f:
+ f.write("Diff:\n")
+ for n in range(4):
+ f.write("BLEU-{}: {:6.3f}\n".format(n+1, bleu_scores[n]))
+ f.write("CIDEr: {:6.3f}\n".format(cider_score))
+ f.write("ROUGE: {:6.3f}\n".format(rouge_score))
+ if not zh:
+ f.write("Meteor: {:6.3f}\n".format(meteor_score))
+ f.write("SPICE: {:6.3f}\n".format(spice_score))
+ # f.write("SentenceBert: {:6.3f}\n".format(sent_bert))
+
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/audio_to_text/captioning/utils/predict_nn.py b/audio_to_text/captioning/utils/predict_nn.py
new file mode 100644
index 0000000..699c3dc
--- /dev/null
+++ b/audio_to_text/captioning/utils/predict_nn.py
@@ -0,0 +1,49 @@
+import json
+import random
+import argparse
+import numpy as np
+from tqdm import tqdm
+from h5py import File
+import sklearn.metrics
+
+random.seed(1)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("train_feature", type=str)
+parser.add_argument("train_corpus", type=str)
+parser.add_argument("pred_feature", type=str)
+parser.add_argument("output_json", type=str)
+
+args = parser.parse_args()
+train_embs = []
+train_idx_to_audioid = []
+with File(args.train_feature, "r") as store:
+ for audio_id, embedding in tqdm(store.items(), ascii=True):
+ train_embs.append(embedding[()])
+ train_idx_to_audioid.append(audio_id)
+
+train_annotation = json.load(open(args.train_corpus, "r"))["audios"]
+train_audioid_to_tokens = {}
+for item in train_annotation:
+ audio_id = item["audio_id"]
+ train_audioid_to_tokens[audio_id] = [cap_item["tokens"] for cap_item in item["captions"]]
+train_embs = np.stack(train_embs)
+
+
+pred_data = []
+pred_embs = []
+pred_idx_to_audioids = []
+with File(args.pred_feature, "r") as store:
+ for audio_id, embedding in tqdm(store.items(), ascii=True):
+ pred_embs.append(embedding[()])
+ pred_idx_to_audioids.append(audio_id)
+pred_embs = np.stack(pred_embs)
+
+similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
+for idx, audio_id in enumerate(pred_idx_to_audioids):
+ train_idx = similarity[idx].argmax()
+ pred_data.append({
+ "filename": audio_id,
+ "tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]])
+ })
+json.dump({"predictions": pred_data}, open(args.output_json, "w"), ensure_ascii=False, indent=4)
diff --git a/audio_to_text/captioning/utils/remove_optimizer.py b/audio_to_text/captioning/utils/remove_optimizer.py
new file mode 100644
index 0000000..2b9871e
--- /dev/null
+++ b/audio_to_text/captioning/utils/remove_optimizer.py
@@ -0,0 +1,18 @@
+import argparse
+import torch
+
+
+def main(checkpoint):
+ state_dict = torch.load(checkpoint, map_location="cpu")
+ if "optimizer" in state_dict:
+ del state_dict["optimizer"]
+ if "lr_scheduler" in state_dict:
+ del state_dict["lr_scheduler"]
+ torch.save(state_dict, checkpoint)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint", type=str)
+ args = parser.parse_args()
+ main(args.checkpoint)
diff --git a/audio_to_text/captioning/utils/report_results.py b/audio_to_text/captioning/utils/report_results.py
new file mode 100644
index 0000000..3b9f6ec
--- /dev/null
+++ b/audio_to_text/captioning/utils/report_results.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+import argparse
+import numpy as np
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--input", help="input filename", type=str, nargs="+")
+parser.add_argument("--output", help="output result file", default=None)
+
+args = parser.parse_args()
+
+
+scores = {}
+for path in args.input:
+ with open(path, "r") as reader:
+ for line in reader.readlines():
+ metric, score = line.strip().split(": ")
+ score = float(score)
+ if metric not in scores:
+ scores[metric] = []
+ scores[metric].append(score)
+
+if len(scores) == 0:
+ print("No experiment directory found, wrong path?")
+ exit(1)
+
+with open(args.output, "w") as writer:
+ print("Average results: ", file=writer)
+ for metric, score in scores.items():
+ score = np.array(score)
+ mean = np.mean(score)
+ std = np.std(score)
+ print(f"{metric}: {mean:.3f} (±{std:.3f})", file=writer)
+ print("", file=writer)
+ print("Best results: ", file=writer)
+ for metric, score in scores.items():
+ score = np.max(score)
+ print(f"{metric}: {score:.3f}", file=writer)
diff --git a/audio_to_text/captioning/utils/tokenize_caption.py b/audio_to_text/captioning/utils/tokenize_caption.py
new file mode 100644
index 0000000..b340068
--- /dev/null
+++ b/audio_to_text/captioning/utils/tokenize_caption.py
@@ -0,0 +1,86 @@
+import json
+from tqdm import tqdm
+import re
+import fire
+
+
+def tokenize_caption(input_json: str,
+ keep_punctuation: bool = False,
+ host_address: str = None,
+ character_level: bool = False,
+ zh: bool = True,
+ output_json: str = None):
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
+
+ Args:
+ input_json(string): Preprossessed json file. Structure like this:
+ {
+ 'audios': [
+ {
+ 'audio_id': 'xxx',
+ 'captions': [
+ {
+ 'caption': 'xxx',
+ 'cap_id': 'xxx'
+ }
+ ]
+ },
+ ...
+ ]
+ }
+ threshold (int): Threshold to drop all words with counts < threshold
+ keep_punctuation (bool): Includes or excludes punctuation.
+
+ Returns:
+ vocab (Vocab): Object with the processed vocabulary
+"""
+ data = json.load(open(input_json, "r"))["audios"]
+
+ if zh:
+ from nltk.parse.corenlp import CoreNLPParser
+ from zhon.hanzi import punctuation
+ parser = CoreNLPParser(host_address)
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
+ # Remove all punctuations
+ if not keep_punctuation:
+ caption = re.sub("[{}]".format(punctuation), "", caption)
+ if character_level:
+ tokens = list(caption)
+ else:
+ tokens = list(parser.tokenize(caption))
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
+ else:
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+ captions = {}
+ for audio_idx in range(len(data)):
+ audio_id = data[audio_idx]["audio_id"]
+ captions[audio_id] = []
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
+ captions[audio_id].append({
+ "audio_id": audio_id,
+ "id": cap_idx,
+ "caption": caption
+ })
+ tokenizer = PTBTokenizer()
+ captions = tokenizer.tokenize(captions)
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
+ audio_id = data[audio_idx]["audio_id"]
+ for cap_idx in range(len(data[audio_idx]["captions"])):
+ tokens = captions[audio_id][cap_idx]
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
+
+ if output_json:
+ json.dump(
+ { "audios": data }, open(output_json, "w"),
+ indent=4, ensure_ascii=not zh)
+ else:
+ json.dump(
+ { "audios": data }, open(input_json, "w"),
+ indent=4, ensure_ascii=not zh)
+
+
+if __name__ == "__main__":
+ fire.Fire(tokenize_caption)
diff --git a/audio_to_text/captioning/utils/train_util.py b/audio_to_text/captioning/utils/train_util.py
new file mode 100644
index 0000000..6cd62cc
--- /dev/null
+++ b/audio_to_text/captioning/utils/train_util.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+#!/usr/bin/env python3
+import os
+import sys
+import logging
+from typing import Callable, Dict, Union
+import yaml
+import torch
+from torch.optim.swa_utils import AveragedModel as torch_average_model
+import numpy as np
+import pandas as pd
+from pprint import pformat
+
+
+def load_dict_from_csv(csv, cols):
+ df = pd.read_csv(csv, sep="\t")
+ output = dict(zip(df[cols[0]], df[cols[1]]))
+ return output
+
+
+def init_logger(filename, level="INFO"):
+ formatter = logging.Formatter(
+ "[ %(levelname)s : %(asctime)s ] - %(message)s")
+ logger = logging.getLogger(__name__ + "." + filename)
+ logger.setLevel(getattr(logging, level))
+ # Log results to std
+ # stdhandler = logging.StreamHandler(sys.stdout)
+ # stdhandler.setFormatter(formatter)
+ # Dump log to file
+ filehandler = logging.FileHandler(filename)
+ filehandler.setFormatter(formatter)
+ logger.addHandler(filehandler)
+ # logger.addHandler(stdhandler)
+ return logger
+
+
+def init_obj(module, config, **kwargs):# 'captioning.models.encoder'
+ obj_args = config["args"].copy()
+ obj_args.update(kwargs)
+ return getattr(module, config["type"])(**obj_args)
+
+
+def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
+ """pprint_dict
+
+ :param outputfun: function to use, defaults to sys.stdout
+ :param in_dict: dict to print
+ """
+ if formatter == 'yaml':
+ format_fun = yaml.dump
+ elif formatter == 'pretty':
+ format_fun = pformat
+ for line in format_fun(in_dict).split('\n'):
+ outputfun(line)
+
+
+def merge_a_into_b(a, b):
+ # merge dict a into dict b. values in a will overwrite b.
+ for k, v in a.items():
+ if isinstance(v, dict) and k in b:
+ assert isinstance(
+ b[k], dict
+ ), "Cannot inherit key '{}' from base!".format(k)
+ merge_a_into_b(v, b[k])
+ else:
+ b[k] = v
+
+
+def load_config(config_file):
+ with open(config_file, "r") as reader:
+ config = yaml.load(reader, Loader=yaml.FullLoader)
+ if "inherit_from" in config:
+ base_config_file = config["inherit_from"]
+ base_config_file = os.path.join(
+ os.path.dirname(config_file), base_config_file
+ )
+ assert not os.path.samefile(config_file, base_config_file), \
+ "inherit from itself"
+ base_config = load_config(base_config_file)
+ del config["inherit_from"]
+ merge_a_into_b(config, base_config)
+ return base_config
+ return config
+
+
+def parse_config_or_kwargs(config_file, **kwargs):
+ yaml_config = load_config(config_file)
+ # passed kwargs will override yaml config
+ args = dict(yaml_config, **kwargs)
+ return args
+
+
+def store_yaml(config, config_file):
+ with open(config_file, "w") as con_writer:
+ yaml.dump(config, con_writer, indent=4, default_flow_style=False)
+
+
+class MetricImprover:
+
+ def __init__(self, mode):
+ assert mode in ("min", "max")
+ self.mode = mode
+ # min: lower -> better; max: higher -> better
+ self.best_value = np.inf if mode == "min" else -np.inf
+
+ def compare(self, x, best_x):
+ return x < best_x if self.mode == "min" else x > best_x
+
+ def __call__(self, x):
+ if self.compare(x, self.best_value):
+ self.best_value = x
+ return True
+ return False
+
+ def state_dict(self):
+ return self.__dict__
+
+ def load_state_dict(self, state_dict):
+ self.__dict__.update(state_dict)
+
+
+def fix_batchnorm(model: torch.nn.Module):
+ def inner(module):
+ class_name = module.__class__.__name__
+ if class_name.find("BatchNorm") != -1:
+ module.eval()
+ model.apply(inner)
+
+
+def load_pretrained_model(model: torch.nn.Module,
+ pretrained: Union[str, Dict],
+ output_fn: Callable = sys.stdout.write):
+ if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
+ output_fn(f"pretrained {pretrained} not exist!")
+ return
+
+ if hasattr(model, "load_pretrained"):
+ model.load_pretrained(pretrained)
+ return
+
+ if isinstance(pretrained, dict):
+ state_dict = pretrained
+ else:
+ state_dict = torch.load(pretrained, map_location="cpu")
+
+ if "model" in state_dict:
+ state_dict = state_dict["model"]
+ model_dict = model.state_dict()
+ pretrained_dict = {
+ k: v for k, v in state_dict.items() if (k in model_dict) and (
+ model_dict[k].shape == v.shape)
+ }
+ output_fn(f"Loading pretrained keys {pretrained_dict.keys()}")
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict, strict=True)
+
+
+class AveragedModel(torch_average_model):
+
+ def update_parameters(self, model):
+ for p_swa, p_model in zip(self.parameters(), model.parameters()):
+ device = p_swa.device
+ p_model_ = p_model.detach().to(device)
+ if self.n_averaged == 0:
+ p_swa.detach().copy_(p_model_)
+ else:
+ p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
+ self.n_averaged.to(device)))
+
+ for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
+ device = b_swa.device
+ b_model_ = b_model.detach().to(device)
+ if self.n_averaged == 0:
+ b_swa.detach().copy_(b_model_)
+ else:
+ b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
+ self.n_averaged.to(device)))
+ self.n_averaged += 1
diff --git a/audio_to_text/captioning/utils/word2vec/create_word_embedding.py b/audio_to_text/captioning/utils/word2vec/create_word_embedding.py
new file mode 100644
index 0000000..77ebe5a
--- /dev/null
+++ b/audio_to_text/captioning/utils/word2vec/create_word_embedding.py
@@ -0,0 +1,67 @@
+# coding=utf-8
+#!/usr/bin/env python3
+
+import numpy as np
+import pandas as pd
+import torch
+import gensim
+from gensim.models import Word2Vec
+from tqdm import tqdm
+import fire
+
+import sys
+import os
+sys.path.append(os.getcwd())
+from utils.build_vocab import Vocabulary
+
+def create_embedding(vocab_file: str,
+ embed_size: int,
+ output: str,
+ caption_file: str = None,
+ pretrained_weights_path: str = None,
+ **word2vec_kwargs):
+ vocabulary = torch.load(vocab_file, map_location="cpu")
+
+ if pretrained_weights_path:
+ model = gensim.models.KeyedVectors.load_word2vec_format(
+ fname=pretrained_weights_path,
+ binary=True,
+ )
+ if model.vector_size != embed_size:
+ assert embed_size < model.vector_size, f"only reduce dimension, cannot add dimesion {model.vector_size} to {embed_size}"
+ from sklearn.decomposition import PCA
+ pca = PCA(n_components=embed_size)
+ model.vectors = pca.fit_transform(model.vectors)
+ else:
+ caption_df = pd.read_json(caption_file)
+ caption_df["tokens"] = caption_df["tokens"].apply(lambda x: [""] + [token for token in x] + [""])
+ sentences = list(caption_df["tokens"].values)
+ epochs = word2vec_kwargs.get("epochs", 10)
+ if "epochs" in word2vec_kwargs:
+ del word2vec_kwargs["epochs"]
+ model = Word2Vec(size=embed_size, min_count=1, **word2vec_kwargs)
+ model.build_vocab(sentences=sentences)
+ model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
+
+ word_embeddings = np.random.randn(len(vocabulary), embed_size)
+
+ if isinstance(model, gensim.models.word2vec.Word2Vec):
+ model = model.wv
+ with tqdm(total=len(vocabulary), ascii=True) as pbar:
+ for word, idx in vocabulary.word2idx.items():
+ try:
+ word_embeddings[idx] = model.get_vector(word)
+ except KeyError:
+ print(f"word {word} not found in word2vec model, it is random initialized!")
+ pbar.update()
+
+ np.save(output, word_embeddings)
+
+ print("Finish writing word2vec embeddings to " + output)
+
+
+if __name__ == "__main__":
+ fire.Fire(create_embedding)
+
+
+
diff --git a/audio_to_text/inference_waveform.py b/audio_to_text/inference_waveform.py
new file mode 100644
index 0000000..aba3961
--- /dev/null
+++ b/audio_to_text/inference_waveform.py
@@ -0,0 +1,102 @@
+import sys
+import os
+import librosa
+import numpy as np
+import torch
+import audio_to_text.captioning.models
+import audio_to_text.captioning.models.encoder
+import audio_to_text.captioning.models.decoder
+import audio_to_text.captioning.utils.train_util as train_util
+
+
+def load_model(config, checkpoint):
+ ckpt = torch.load(checkpoint, "cpu")
+ encoder_cfg = config["model"]["encoder"]
+ encoder = train_util.init_obj(
+ audio_to_text.captioning.models.encoder,
+ encoder_cfg
+ )
+ if "pretrained" in encoder_cfg:
+ pretrained = encoder_cfg["pretrained"]
+ train_util.load_pretrained_model(encoder,
+ pretrained,
+ sys.stdout.write)
+ decoder_cfg = config["model"]["decoder"]
+ if "vocab_size" not in decoder_cfg["args"]:
+ decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
+ decoder = train_util.init_obj(
+ audio_to_text.captioning.models.decoder,
+ decoder_cfg
+ )
+ if "word_embedding" in decoder_cfg:
+ decoder.load_word_embedding(**decoder_cfg["word_embedding"])
+ if "pretrained" in decoder_cfg:
+ pretrained = decoder_cfg["pretrained"]
+ train_util.load_pretrained_model(decoder,
+ pretrained,
+ sys.stdout.write)
+ model = train_util.init_obj(audio_to_text.captioning.models, config["model"],
+ encoder=encoder, decoder=decoder)
+ train_util.load_pretrained_model(model, ckpt)
+ model.eval()
+ return {
+ "model": model,
+ "vocabulary": ckpt["vocabulary"]
+ }
+
+
+def decode_caption(word_ids, vocabulary):
+ candidate = []
+ for word_id in word_ids:
+ word = vocabulary[word_id]
+ if word == "":
+ break
+ elif word == "":
+ continue
+ candidate.append(word)
+ candidate = " ".join(candidate)
+ return candidate
+
+
+class AudioCapModel(object):
+ def __init__(self,weight_dir,device='cuda'):
+ config = os.path.join(weight_dir,'config.yaml')
+ self.config = train_util.parse_config_or_kwargs(config)
+ checkpoint = os.path.join(weight_dir,'swa.pth')
+ resumed = load_model(self.config, checkpoint)
+ model = resumed["model"]
+ self.vocabulary = resumed["vocabulary"]
+ self.model = model.to(device)
+ self.device = device
+
+ def caption(self,audio_list):
+ if isinstance(audio_list,np.ndarray):
+ audio_list = [audio_list]
+ elif isinstance(audio_list,str):
+ audio_list = [librosa.load(audio_list,sr=32000)[0]]
+
+ captions = []
+ for wav in audio_list:
+ inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
+ wav_len = torch.LongTensor([len(wav)])
+ input_dict = {
+ "mode": "inference",
+ "wav": inputwav,
+ "wav_len": wav_len,
+ "specaug": False,
+ "sample_method": "beam",
+ }
+ print(input_dict)
+ out_dict = self.model(input_dict)
+ caption_batch = [decode_caption(seq, self.vocabulary) for seq in \
+ out_dict["seq"].cpu().numpy()]
+ captions.extend(caption_batch)
+ return captions
+
+
+
+ def __call__(self, audio_list):
+ return self.caption(audio_list)
+
+
+
diff --git a/download.sh b/download.sh
index 8de681a..e6be9cc 100644
--- a/download.sh
+++ b/download.sh
@@ -26,4 +26,9 @@ wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/Gene
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/spk_map.json
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/train_f0s_mean_std.npy
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/word_set.json
-
+wget -P text_to_speech/checkpoints/hifi_lj -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/text_to_speech/checkpoints/hifi_lj/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/text_to_speech/checkpoints/hifi_lj/model_ckpt_steps_2076000.ckpt
+wget -P text_to_speech/checkpoints/ljspeech/ps_adv_baseline -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/text_to_speech/checkpoints/ljspeech/ps_adv_baseline/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160000.ckpt https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160001.ckpt
+# Audio to text
+wget -P audio_to_text/audiocaps_cntrstv_cnn14rnn_trm -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth
+wget -P audio_to_text/clotho_cntrstv_cnn14rnn_trm -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth
+wget -P audio_to_text/pretrained_feature_extractors https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth
\ No newline at end of file