mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-16 20:07:58 +01:00
Merge pull request #6 from lmzjms/main
update audio-caption, text to speech
This commit is contained in:
BIN
assets/2bf90e35.wav
Normal file
BIN
assets/2bf90e35.wav
Normal file
Binary file not shown.
BIN
assets/5d67d1b9.wav
Normal file
BIN
assets/5d67d1b9.wav
Normal file
Binary file not shown.
@@ -7,21 +7,45 @@ Output:<br />
|
||||
Input Example : Generate an audio of a piano playing<br />
|
||||
Output:<br />
|
||||
<br />
|
||||
Audio:<br />
|
||||
<audio src="b973e878.wav" controls></audio><br />
|
||||
|
||||
## Text-To-Speech
|
||||
Input Example : Generate a speech with text "here we go"<br />
|
||||
Output:<br />
|
||||
<br />
|
||||
Audio:<br />
|
||||
<audio src="fd5cf55e.wav" controls></audio><br />
|
||||
|
||||
## Text-To-Sing
|
||||
Input example : please generate a piece of singing voice. Text sequence is 小酒窝长睫毛AP是你最美的记号. Note sequence is C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4. Note duration sequence is 0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340.<br />
|
||||
Output:<br />
|
||||
<br />
|
||||
Audio:<br />
|
||||
<audio src="2bf90e35.wav" controls></audio><br />
|
||||
## Image-To-Audio
|
||||
First upload your image(.png)<br />
|
||||
Input Example : Generate the audio of this image<br />
|
||||
Output:<br />
|
||||
<br />
|
||||
## ASR
|
||||
Audio:<br />
|
||||
<audio src="5d67d1b9.wav" controls></audio><br />
|
||||
|
||||
## Speech Recognition
|
||||
First upload your audio(.wav)<br />
|
||||
Input Example : Generate the text of this audio<br />
|
||||
Audio Example :<br />
|
||||
<audio src="Track 4.wav" controls></audio><br />
|
||||
Input Example : Generate the text of this speech<br />
|
||||
Output:<br />
|
||||
<br />
|
||||
|
||||
## Audio-To-Text
|
||||
First upload your audio(.wav)<br />
|
||||
Audio Example :<br />
|
||||
<audio src="a-group-of-sheep-are-baaing.wav" controls></audio><br />
|
||||
Input Example : Please tell me the text description of this audio.<br />
|
||||
Output:<br />
|
||||
<br />
|
||||
## Style Transfer Text-To-Speech
|
||||
First upload your audio(.wav)<br />
|
||||
Input Example : Speak using the voice of this audio. The text is "here we go".<br />
|
||||
|
||||
BIN
assets/Track 4.wav
Normal file
BIN
assets/Track 4.wav
Normal file
Binary file not shown.
BIN
assets/a-group-of-sheep-are-baaing.wav
Normal file
BIN
assets/a-group-of-sheep-are-baaing.wav
Normal file
Binary file not shown.
BIN
assets/a2i.png
Normal file
BIN
assets/a2i.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
BIN
assets/b973e878.wav
Normal file
BIN
assets/b973e878.wav
Normal file
Binary file not shown.
BIN
assets/fd5cf55e.wav
Normal file
BIN
assets/fd5cf55e.wav
Normal file
Binary file not shown.
BIN
assets/tts.png
Normal file
BIN
assets/tts.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 58 KiB |
116
audio-chatgpt.py
116
audio-chatgpt.py
@@ -18,7 +18,6 @@ from langchain.llms.openai import OpenAI
|
||||
import re
|
||||
import uuid
|
||||
import soundfile
|
||||
from scipy.io import wavfile
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
@@ -36,20 +35,22 @@ from vocoder.bigvgan.models import VocoderBigVGAN
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from wav_evaluation.models.CLAPWrapper import CLAPWrapper
|
||||
from inference.svs.ds_e2e import DiffSingerE2EInfer
|
||||
from audio_to_text.inference_waveform import AudioCapModel
|
||||
import whisper
|
||||
from text_to_speech.TTS_binding import TTSInference
|
||||
|
||||
import torch
|
||||
from inference.svs.ds_e2e import DiffSingerE2EInfer
|
||||
from inference.tts.GenerSpeech import GenerSpeechInfer
|
||||
from utils.hparams import set_hparams
|
||||
from utils.hparams import hparams as hp
|
||||
from utils.os_utils import move_file
|
||||
import scipy.io.wavfile as wavfile
|
||||
|
||||
AUDIO_CHATGPT_PREFIX = """Audio ChatGPT
|
||||
AUdio ChatGPT can not directly read audios, but it has a list of tools to finish different audio synthesis tasks. Each audio will have a file name formed as "audio/xxx.wav". When talking about audios, Audio ChatGPT is very strict to the file name and will never fabricate nonexistent files.
|
||||
AUdio ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the audio content and audio file name. It will remember to provide the file name from the last tool observation, if a new audio is generated.
|
||||
Human may provide Audio ChatGPT with a description. Audio ChatGPT should generate audios according to this description rather than directly imagine from memory or yourself."
|
||||
|
||||
|
||||
TOOLS:
|
||||
------
|
||||
|
||||
@@ -87,7 +88,7 @@ Thought: Do I need to use a tool? {agent_scratchpad}"""
|
||||
def cut_dialogue_history(history_memory, keep_last_n_words = 500):
|
||||
tokens = history_memory.split()
|
||||
n_tokens = len(tokens)
|
||||
print(f"hitory_memory:{history_memory}, n_tokens: {n_tokens}")
|
||||
print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
|
||||
if n_tokens < keep_last_n_words:
|
||||
return history_memory
|
||||
else:
|
||||
@@ -125,33 +126,6 @@ def select_best_audio(prompt,wav_list):
|
||||
print(score_list,max_index)
|
||||
return wav_list[max_index]
|
||||
|
||||
class MaskFormer:
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
|
||||
|
||||
def inference(self, image_path, text):
|
||||
threshold = 0.5
|
||||
min_area = 0.02
|
||||
padding = 20
|
||||
original_image = Image.open(image_path)
|
||||
image = original_image.resize((512, 512))
|
||||
inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt",).to(self.device)
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
|
||||
area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
|
||||
if area_ratio < min_area:
|
||||
return None
|
||||
true_indices = np.argwhere(mask)
|
||||
mask_array = np.zeros_like(mask, dtype=bool)
|
||||
for idx in true_indices:
|
||||
padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
|
||||
mask_array[padded_slice] = True
|
||||
visual_mask = (mask_array * 255).astype(np.uint8)
|
||||
image_mask = Image.fromarray(visual_mask)
|
||||
return image_mask.resize(image.size)
|
||||
|
||||
|
||||
class T2I:
|
||||
@@ -191,7 +165,7 @@ class T2A:
|
||||
print("Initializing Make-An-Audio to %s" % device)
|
||||
self.device = device
|
||||
self.sampler = initialize_model('configs/text-to-audio/txt2audio_args.yaml', 'useful_ckpts/ta40multi_epoch=000085.ckpt', device=device)
|
||||
self.vocoder = VocoderHifigan('vocoder/logs/hifi_0127',device=device)
|
||||
self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w',device=device)
|
||||
|
||||
def txt2audio(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples = 3, W = 624, H = 80):
|
||||
SAMPLE_RATE = 16000
|
||||
@@ -250,7 +224,7 @@ class I2A:
|
||||
image = Image.open(image)
|
||||
image = self.sampler.model.cond_stage_model.preprocess(image).unsqueeze(0)
|
||||
image_embedding = self.sampler.model.cond_stage_model.forward_img(image)
|
||||
c = image_embedding.repeat(n_samples, 1, 1)
|
||||
c = image_embedding.repeat(n_samples, 1, 1)# shape:[1,77,1280],即还没有变成句子embedding,仍是每个单词的embedding
|
||||
shape = [self.sampler.model.first_stage_model.embed_dim, H//8, W//8] # (z_dim, 80//2^x, 848//2^x)
|
||||
samples_ddim, _ = self.sampler.sample(S=ddim_steps,
|
||||
conditioning=c,
|
||||
@@ -291,7 +265,6 @@ class TTS:
|
||||
inp = {"text": text}
|
||||
out = self.inferencer.infer_once(inp)
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
temp_audio_filename = audio_filename
|
||||
soundfile.write(audio_filename, out, samplerate = 22050)
|
||||
return audio_filename
|
||||
|
||||
@@ -305,7 +278,7 @@ class T2S:
|
||||
self.config= 'text_to_sing/DiffSinger/usr/configs/midi/e2e/opencpop/ds1000.yaml'
|
||||
self.set_model_hparams()
|
||||
self.pipe = DiffSingerE2EInfer(self.hp, device)
|
||||
self.defualt_inp = {
|
||||
self.default_inp = {
|
||||
'text': '你 说 你 不 SP 懂 为 何 在 这 时 牵 手 AP',
|
||||
'notes': 'D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | rest | D#4/Eb4 | D4 | D4 | D4 | D#4/Eb4 | F4 | D#4/Eb4 | D4 | rest',
|
||||
'notes_duration': '0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590'
|
||||
@@ -320,7 +293,7 @@ class T2S:
|
||||
val = inputs.split(",")
|
||||
key = ['text', 'notes', 'notes_duration']
|
||||
if inputs == '' or len(val) < len(key):
|
||||
inp = self.defualt_inp
|
||||
inp = self.default_inp
|
||||
else:
|
||||
inp = {k:v for k,v in zip(key,val)}
|
||||
wav = self.pipe.infer_once(inp)
|
||||
@@ -356,6 +329,7 @@ class TTS_OOD:
|
||||
key = ['ref_audio', 'text']
|
||||
val = inputs.split(",")
|
||||
inp = {k: v for k, v in zip(key, val)}
|
||||
print(inp)
|
||||
wav = self.pipe.infer_once(inp)
|
||||
wav *= 32767
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
@@ -364,16 +338,13 @@ class TTS_OOD:
|
||||
f"Processed GenerSpeech.run. Input text:{val[1]}. Input reference audio: {val[0]}. Output Audio_filename: {audio_filename}")
|
||||
return audio_filename
|
||||
|
||||
|
||||
class Inpaint:
|
||||
def __init__(self, device):
|
||||
print("Initializing Make-An-Audio-inpaint to %s" % device)
|
||||
self.device = device
|
||||
self.sampler = initialize_model('text_to_audio/Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml',
|
||||
'text_to_audio/Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
|
||||
self.sampler = initialize_model('text_to_audio/Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
|
||||
self.vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w',device=device)
|
||||
|
||||
def make_batch_sd(self, mel, mask, num_samples=1):
|
||||
def make_batch_sd(mel, mask, num_samples=1):
|
||||
|
||||
mel = torch.from_numpy(mel)[None,None,...].to(dtype=torch.float32)
|
||||
mask = torch.from_numpy(mask)[None,None,...].to(dtype=torch.float32)
|
||||
@@ -389,15 +360,13 @@ class Inpaint:
|
||||
"masked_mel": repeat(masked_mel.to(device=self.device), "1 ... -> n ...", n=num_samples),
|
||||
}
|
||||
return batch
|
||||
|
||||
def gen_mel(self, input_audio):
|
||||
def gen_mel(input_audio):
|
||||
sr,ori_wav = input_audio
|
||||
print(sr,ori_wav.shape,ori_wav)
|
||||
|
||||
ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0 # order='C'是以C语言格式存储,不用管
|
||||
if len(ori_wav.shape)==2:# stereo
|
||||
ori_wav = librosa.to_mono(
|
||||
ori_wav.T) # gradio load wav shape could be (wav_len,2) but librosa expects (2,wav_len)
|
||||
ori_wav = librosa.to_mono(ori_wav.T)# gradio load wav shape could be (wav_len,2) but librosa expects (2,wav_len)
|
||||
print(sr,ori_wav.shape,ori_wav)
|
||||
ori_wav = librosa.resample(ori_wav,orig_sr = sr,target_sr = SAMPLE_RATE)
|
||||
|
||||
@@ -410,14 +379,12 @@ class Inpaint:
|
||||
|
||||
mel = TRANSFORMS_16000(input_wav)
|
||||
return mel
|
||||
|
||||
def show_mel_fn(self, input_audio):
|
||||
def show_mel_fn(input_audio):
|
||||
crop_len = 500 # the full mel cannot be showed due to gradio's Image bug when using tool='sketch'
|
||||
crop_mel = self.gen_mel(input_audio)[:,:crop_len]
|
||||
color_mel = cmap_transform(crop_mel)
|
||||
return Image.fromarray((color_mel*255).astype(np.uint8))
|
||||
|
||||
def inpaint(self, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
|
||||
def inpaint(batch, seed, ddim_steps, num_samples=1, W=512, H=512):
|
||||
model = self.sampler.model
|
||||
|
||||
prng = np.random.RandomState(seed)
|
||||
@@ -437,6 +404,7 @@ class Inpaint:
|
||||
verbose=False)
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
|
||||
|
||||
mask = batch["mask"]# [-1,1]
|
||||
mel = torch.clamp((batch["mel"]+1.0)/2.0,min=0.0, max=1.0)
|
||||
mask = torch.clamp((batch["mask"]+1.0)/2.0,min=0.0, max=1.0)
|
||||
@@ -446,16 +414,14 @@ class Inpaint:
|
||||
inapint_wav = self.vocoder.vocode(inpainted)
|
||||
|
||||
return inpainted, inapint_wav
|
||||
|
||||
def predict(self, input_audio, mel_and_mask, ddim_steps, seed):
|
||||
def predict(input_audio,mel_and_mask,ddim_steps,seed):
|
||||
show_mel = np.array(mel_and_mask['image'].convert("L"))/255 # 由于展示的mel只展示了一部分,所以需要重新从音频生成mel
|
||||
mask = np.array(mel_and_mask["mask"].convert("L"))/255
|
||||
|
||||
mel_bins,mel_len = 80,848
|
||||
|
||||
input_mel = self.gen_mel(input_audio)[:,:mel_len]# 由于展示的mel只展示了一部分,所以需要重新从音频生成mel
|
||||
mask = np.pad(mask, ((0, 0), (0, mel_len - mask.shape[1])), mode='constant',
|
||||
constant_values=0) # 将mask填充到原来的mel的大小
|
||||
mask = np.pad(mask,((0,0),(0,mel_len-mask.shape[1])),mode='constant',constant_values=0)# 将mask填充到原来的mel的大小
|
||||
print(mask.shape,input_mel.shape)
|
||||
with torch.no_grad():
|
||||
batch = make_batch_sd(input_mel,mask,device,num_samples=1)
|
||||
@@ -472,13 +438,11 @@ class Inpaint:
|
||||
gen_wav = (gen_wav * 32768).astype(np.int16)[:input_len]
|
||||
return Image.fromarray((color_mel*255).astype(np.uint8)),(SAMPLE_RATE,gen_wav)
|
||||
|
||||
|
||||
class ASR:
|
||||
def __init__(self, device):
|
||||
print("Initializing Whisper to %s" % device)
|
||||
self.device = device
|
||||
self.model = whisper.load_model("base", device=device)
|
||||
|
||||
def inference(self, audio_path):
|
||||
audio = whisper.load_audio(audio_path)
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
@@ -488,6 +452,16 @@ class ASR:
|
||||
result = whisper.decode(self.model, mel, options)
|
||||
return result.text
|
||||
|
||||
class A2T:
|
||||
def __init__(self, device):
|
||||
print("Initializing Audio-To-Text Model to %s" % device)
|
||||
self.device = device
|
||||
self.model = AudioCapModel("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm")
|
||||
def inference(self, audio_path):
|
||||
audio = whisper.load_audio(audio_path)
|
||||
caption_text = self.model(audio)
|
||||
return caption_text[0]
|
||||
|
||||
class ConversationBot:
|
||||
def __init__(self):
|
||||
print("Initializing AudioChatGPT")
|
||||
@@ -498,37 +472,40 @@ class ConversationBot:
|
||||
self.tts = TTS(device="cuda:0")
|
||||
self.t2s = T2S(device="cuda:2")
|
||||
self.i2a = I2A(device="cuda:1")
|
||||
self.a2t = A2T(device="cuda:2")
|
||||
self.asr = ASR(device="cuda:1")
|
||||
self.t2s = T2S(device="cuda:0")
|
||||
self.tts_ood = TTS_OOD(device="cuda:0")
|
||||
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
||||
self.tools = [
|
||||
Tool(name="Generate Image From User Input Text", func=self.t2i.inference,
|
||||
description="useful for when you want to generate an image from a user input text and saved it to a file. like: generate an image of an object or something, or generate an image that includes some objects. "
|
||||
description="useful for when you want to generate an image from a user input text and it saved it to a file. like: generate an image of an object or something, or generate an image that includes some objects. "
|
||||
"The input to this tool should be a string, representing the text used to generate image. "),
|
||||
Tool(name="Get Photo Description", func=self.i2t.inference,
|
||||
description="useful for when you want to know what is inside the photo. receives image_path as input. "
|
||||
"The input to this tool should be a string, representing the image_path. "),
|
||||
Tool(name="Generate Audio From User Input Text", func=self.t2a.inference,
|
||||
description="useful for when you want to generate an audio from a user input text and it saved it to a file. like: generate an audio of something, or generate an audio that includes some objects. "
|
||||
description="useful for when you want to generate an audio from a user input text and it saved it to a file."
|
||||
"The input to this tool should be a string, representing the text used to generate audio."),
|
||||
Tool(
|
||||
name="Generate human speech with style derived from a speech reference and user input text and save it to a file", func= self.tts_ood.inference,
|
||||
description="useful for when you want to generate speech samples with styles (e.g., timbre, emotion, and prosody) derived from a reference custom voice."
|
||||
"ike: Generate a speech with style transferred from this voice. The text is xxx., or speak using the voice of this audio. The text is xxx."
|
||||
"Like: Generate a speech with style transferred from this voice. The text is xxx., or speak using the voice of this audio. The text is xxx."
|
||||
"The input to this tool should be a comma seperated string of two, representing reference audio path and input text."),
|
||||
Tool(name="Generate singing voice From User Input Text, Note and Duration Sequence", func= self.t2s.inference,
|
||||
description="useful for when you want to generate singing voice (Optional: from User Input Text, Note and Duration Sequence) and save it to a file."
|
||||
description="useful for when you want to generate a piece of singing voice (Optional: from User Input Text, Note and Duration Sequence) and save it to a file."
|
||||
"If Like: Generate a piece of singing voice, the input to this tool should be \"\" since there is no User Input Text, Note and Duration Sequence ."
|
||||
"If Like: Generate a piece of singing voice. Text: xxx, Note: xxx, Duration: xxx. "
|
||||
"Or Like: Generate a piece of singing voice. Text is xxx, note is xxx, duration is xxx."
|
||||
"The input to this tool should be a comma seperated string of three, representing text, note and duration sequence since User Input Text, Note and Duration Sequence are all provided."),
|
||||
Tool(name="Synthesize Speech Given the User Input Text", func=self.tts.inference,
|
||||
description="useful for when you want to convert a user input text into speech and saved it to a file."
|
||||
description="useful for when you want to convert a user input text into speech audio it saved it to a file."
|
||||
"The input to this tool should be a string, representing the text used to be converted to speech."),
|
||||
Tool(name="Generate Audio From The Image", func=self.i2a.inference,
|
||||
description="useful for when you want to generate an audio based on an image."
|
||||
"The input to this tool should be a string, representing the image_path. "),
|
||||
Tool(name="Generate Text From The Audio", func=self.a2t.inference,
|
||||
description="useful for when you want to describe an audio in text, receives audio_path as input."
|
||||
"The input to this tool should be a string, representing the audio_path."),
|
||||
Tool(name="Transcribe speech", func=self.asr.inference,
|
||||
description="useful for when you want to know the text corresponding to a human speech, receives audio_path as input."
|
||||
"The input to this tool should be a string, representing the audio_path.")]
|
||||
@@ -547,15 +524,21 @@ class ConversationBot:
|
||||
print("======>Previous memory:\n %s" % self.agent.memory)
|
||||
self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
|
||||
res = self.agent({"input": text})
|
||||
if res['intermediate_steps'] == []:
|
||||
print("======>Current memory:\n %s" % self.agent.memory)
|
||||
response = res['output']
|
||||
state = state + [(text, response)]
|
||||
print("Outputs:", state)
|
||||
return state, state, None
|
||||
else:
|
||||
tool = res['intermediate_steps'][0][0].tool
|
||||
if tool == "Generate Image From User Input Text":
|
||||
if tool == "Generate Image From User Input Text" or tool == "Generate Text From The Audio" or tool == "Transcribe speech":
|
||||
print("======>Current memory:\n %s" % self.agent.memory)
|
||||
response = re.sub('(image/\S*png)', lambda m: f'})*{m.group(0)}*', res['output'])
|
||||
state = state + [(text, response)]
|
||||
print("Outputs:", state)
|
||||
return state, state, None
|
||||
print("======>Current memory:\n %s" % self.agent.memory)
|
||||
audio_filename = res['intermediate_steps'][0][1]
|
||||
response = re.sub('(image/\S*png)', lambda m: f'})*{m.group(0)}*', res['output'])
|
||||
#response = res['output'] + f"<audio src=audio_filename controls=controls></audio>"
|
||||
state = state + [(text, response)]
|
||||
@@ -569,10 +552,9 @@ class ConversationBot:
|
||||
print("Inputs:", file, state)
|
||||
print("======>Previous memory:\n %s" % self.agent.memory)
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
print("======>Auto Resize Audio...")
|
||||
audio_load = whisper.load_audio(file.name)
|
||||
soundfile.write(audio_filename, audio_load, samplerate = 16000)
|
||||
description = self.asr.inference(audio_filename)
|
||||
description = self.a2t.inference(audio_filename)
|
||||
Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
|
||||
"rather than directly imagine from my description. If you understand, say \"Received\". \n".format(audio_filename, description)
|
||||
AI_prompt = "Received. "
|
||||
@@ -605,6 +587,7 @@ class ConversationBot:
|
||||
print("Outputs:", state)
|
||||
return state, state, txt + ' ' + image_filename + ' ', None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
bot = ConversationBot()
|
||||
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
|
||||
@@ -614,7 +597,7 @@ if __name__ == '__main__':
|
||||
state = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.7):
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image or audio").style(container=False)
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(container=False)
|
||||
with gr.Column(scale=0.15, min_width=0):
|
||||
clear = gr.Button("Clear️")
|
||||
with gr.Column(scale=0.15, min_width=0):
|
||||
@@ -627,4 +610,5 @@ if __name__ == '__main__':
|
||||
clear.click(bot.memory.clear)
|
||||
clear.click(lambda: [], None, chatbot)
|
||||
clear.click(lambda: [], None, state)
|
||||
clear.click(lambda: None, None, outaudio)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
||||
0
audio_to_text/__init__.py
Normal file
0
audio_to_text/__init__.py
Normal file
0
audio_to_text/captioning/__init__.py
Normal file
0
audio_to_text/captioning/__init__.py
Normal file
3
audio_to_text/captioning/models/__init__.py
Normal file
3
audio_to_text/captioning/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base_model import *
|
||||
from .transformer_model import *
|
||||
|
||||
500
audio_to_text/captioning/models/base_model.py
Normal file
500
audio_to_text/captioning/models/base_model.py
Normal file
@@ -0,0 +1,500 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import mean_with_lens, repeat_tensor
|
||||
|
||||
|
||||
class CaptionModel(nn.Module):
|
||||
"""
|
||||
Encoder-decoder captioning model.
|
||||
"""
|
||||
|
||||
pad_idx = 0
|
||||
start_idx = 1
|
||||
end_idx = 2
|
||||
max_length = 20
|
||||
|
||||
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.vocab_size = decoder.vocab_size
|
||||
self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
|
||||
self.inference_forward_keys = ["sample_method", "max_length", "temp"]
|
||||
freeze_encoder = kwargs.get("freeze_encoder", False)
|
||||
if freeze_encoder:
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.check_decoder_compatibility()
|
||||
|
||||
def check_decoder_compatibility(self):
|
||||
compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
|
||||
assert isinstance(self.decoder, self.compatible_decoders), \
|
||||
f"{self.decoder.__class__.__name__} is incompatible with " \
|
||||
f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
|
||||
|
||||
@classmethod
|
||||
def set_index(cls, start_idx, end_idx):
|
||||
cls.start_idx = start_idx
|
||||
cls.end_idx = end_idx
|
||||
|
||||
def forward(self, input_dict: Dict):
|
||||
"""
|
||||
input_dict: {
|
||||
(required)
|
||||
mode: train/inference,
|
||||
spec,
|
||||
spec_len,
|
||||
fc,
|
||||
attn,
|
||||
attn_len,
|
||||
[sample_method: greedy],
|
||||
[temp: 1.0] (in case of no teacher forcing)
|
||||
|
||||
(optional, mode=train)
|
||||
cap,
|
||||
cap_len,
|
||||
ss_ratio,
|
||||
|
||||
(optional, mode=inference)
|
||||
sample_method: greedy/beam,
|
||||
max_length,
|
||||
temp,
|
||||
beam_size (optional, sample_method=beam),
|
||||
n_best (optional, sample_method=beam),
|
||||
}
|
||||
"""
|
||||
# encoder_input_keys = ["spec", "spec_len", "fc", "attn", "attn_len"]
|
||||
# encoder_input = { key: input_dict[key] for key in encoder_input_keys }
|
||||
encoder_output_dict = self.encoder(input_dict)
|
||||
if input_dict["mode"] == "train":
|
||||
forward_dict = {
|
||||
"mode": "train", "sample_method": "greedy", "temp": 1.0
|
||||
}
|
||||
for key in self.train_forward_keys:
|
||||
forward_dict[key] = input_dict[key]
|
||||
forward_dict.update(encoder_output_dict)
|
||||
output = self.train_forward(forward_dict)
|
||||
elif input_dict["mode"] == "inference":
|
||||
forward_dict = {"mode": "inference"}
|
||||
default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
|
||||
for key in self.inference_forward_keys:
|
||||
if key in input_dict:
|
||||
forward_dict[key] = input_dict[key]
|
||||
else:
|
||||
forward_dict[key] = default_args[key]
|
||||
|
||||
if forward_dict["sample_method"] == "beam":
|
||||
forward_dict["beam_size"] = input_dict.get("beam_size", 3)
|
||||
forward_dict["n_best"] = input_dict.get("n_best", False)
|
||||
forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
|
||||
elif forward_dict["sample_method"] == "dbs":
|
||||
forward_dict["beam_size"] = input_dict.get("beam_size", 6)
|
||||
forward_dict["group_size"] = input_dict.get("group_size", 3)
|
||||
forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
|
||||
forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
|
||||
|
||||
forward_dict.update(encoder_output_dict)
|
||||
output = self.inference_forward(forward_dict)
|
||||
else:
|
||||
raise Exception("mode should be either 'train' or 'inference'")
|
||||
|
||||
return output
|
||||
|
||||
def prepare_output(self, input_dict):
|
||||
output = {}
|
||||
batch_size = input_dict["fc_emb"].size(0)
|
||||
if input_dict["mode"] == "train":
|
||||
max_length = input_dict["cap"].size(1) - 1
|
||||
elif input_dict["mode"] == "inference":
|
||||
max_length = input_dict["max_length"]
|
||||
else:
|
||||
raise Exception("mode should be either 'train' or 'inference'")
|
||||
device = input_dict["fc_emb"].device
|
||||
output["seq"] = torch.full((batch_size, max_length), self.end_idx,
|
||||
dtype=torch.long)
|
||||
output["logit"] = torch.empty(batch_size, max_length,
|
||||
self.vocab_size).to(device)
|
||||
output["sampled_logprob"] = torch.zeros(batch_size, max_length)
|
||||
output["embed"] = torch.empty(batch_size, max_length,
|
||||
self.decoder.d_model).to(device)
|
||||
return output
|
||||
|
||||
def train_forward(self, input_dict):
|
||||
if input_dict["ss_ratio"] != 1: # scheduled sampling training
|
||||
input_dict["mode"] = "train"
|
||||
return self.stepwise_forward(input_dict)
|
||||
output = self.seq_forward(input_dict)
|
||||
self.train_process(output, input_dict)
|
||||
return output
|
||||
|
||||
def seq_forward(self, input_dict):
|
||||
raise NotImplementedError
|
||||
|
||||
def train_process(self, output, input_dict):
|
||||
pass
|
||||
|
||||
def inference_forward(self, input_dict):
|
||||
if input_dict["sample_method"] == "beam":
|
||||
return self.beam_search(input_dict)
|
||||
elif input_dict["sample_method"] == "dbs":
|
||||
return self.diverse_beam_search(input_dict)
|
||||
return self.stepwise_forward(input_dict)
|
||||
|
||||
def stepwise_forward(self, input_dict):
|
||||
"""Step-by-step decoding"""
|
||||
output = self.prepare_output(input_dict)
|
||||
max_length = output["seq"].size(1)
|
||||
# start sampling
|
||||
for t in range(max_length):
|
||||
input_dict["t"] = t
|
||||
self.decode_step(input_dict, output)
|
||||
if input_dict["mode"] == "inference": # decide whether to stop when sampling
|
||||
unfinished_t = output["seq"][:, t] != self.end_idx
|
||||
if t == 0:
|
||||
unfinished = unfinished_t
|
||||
else:
|
||||
unfinished *= unfinished_t
|
||||
output["seq"][:, t][~unfinished] = self.end_idx
|
||||
if unfinished.sum() == 0:
|
||||
break
|
||||
self.stepwise_process(output)
|
||||
return output
|
||||
|
||||
def decode_step(self, input_dict, output):
|
||||
"""Decoding operation of timestep t"""
|
||||
decoder_input = self.prepare_decoder_input(input_dict, output)
|
||||
# feed to the decoder to get logit
|
||||
output_t = self.decoder(decoder_input)
|
||||
logit_t = output_t["logit"]
|
||||
# assert logit_t.ndim == 3
|
||||
if logit_t.size(1) == 1:
|
||||
logit_t = logit_t.squeeze(1)
|
||||
embed_t = output_t["embed"].squeeze(1)
|
||||
elif logit_t.size(1) > 1:
|
||||
logit_t = logit_t[:, -1, :]
|
||||
embed_t = output_t["embed"][:, -1, :]
|
||||
else:
|
||||
raise Exception("no logit output")
|
||||
# sample the next input word and get the corresponding logit
|
||||
sampled = self.sample_next_word(logit_t,
|
||||
method=input_dict["sample_method"],
|
||||
temp=input_dict["temp"])
|
||||
|
||||
output_t.update(sampled)
|
||||
output_t["t"] = input_dict["t"]
|
||||
output_t["logit"] = logit_t
|
||||
output_t["embed"] = embed_t
|
||||
self.stepwise_process_step(output, output_t)
|
||||
|
||||
def prepare_decoder_input(self, input_dict, output):
|
||||
"""Prepare the inp ut dict for the decoder"""
|
||||
raise NotImplementedError
|
||||
|
||||
def stepwise_process_step(self, output, output_t):
|
||||
"""Postprocessing (save output values) after each timestep t"""
|
||||
t = output_t["t"]
|
||||
output["logit"][:, t, :] = output_t["logit"]
|
||||
output["seq"][:, t] = output_t["word"]
|
||||
output["sampled_logprob"][:, t] = output_t["probs"]
|
||||
output["embed"][:, t, :] = output_t["embed"]
|
||||
|
||||
def stepwise_process(self, output):
|
||||
"""Postprocessing after the whole step-by-step autoregressive decoding"""
|
||||
pass
|
||||
|
||||
def sample_next_word(self, logit, method, temp):
|
||||
"""Sample the next word, given probs output by the decoder"""
|
||||
logprob = torch.log_softmax(logit, dim=1)
|
||||
if method == "greedy":
|
||||
sampled_logprob, word = torch.max(logprob.detach(), 1)
|
||||
elif method == "gumbel":
|
||||
def sample_gumbel(shape, eps=1e-20):
|
||||
U = torch.rand(shape).to(logprob.device)
|
||||
return -torch.log(-torch.log(U + eps) + eps)
|
||||
def gumbel_softmax_sample(logit, temperature):
|
||||
y = logit + sample_gumbel(logit.size())
|
||||
return torch.log_softmax(y / temperature, dim=-1)
|
||||
_logprob = gumbel_softmax_sample(logprob, temp)
|
||||
_, word = torch.max(_logprob.data, 1)
|
||||
sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
|
||||
else:
|
||||
logprob = logprob / temp
|
||||
if method.startswith("top"):
|
||||
top_num = float(method[3:])
|
||||
if 0 < top_num < 1: # top-p sampling
|
||||
probs = torch.softmax(logit, dim=1)
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
|
||||
_cumsum = sorted_probs.cumsum(1)
|
||||
mask = _cumsum < top_num
|
||||
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
|
||||
sorted_probs = sorted_probs * mask.to(sorted_probs)
|
||||
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
|
||||
logprob.scatter_(1, sorted_indices, sorted_probs.log())
|
||||
else: # top-k sampling
|
||||
k = int(top_num)
|
||||
tmp = torch.empty_like(logprob).fill_(float('-inf'))
|
||||
topk, indices = torch.topk(logprob, k, dim=1)
|
||||
tmp = tmp.scatter(1, indices, topk)
|
||||
logprob = tmp
|
||||
word = torch.distributions.Categorical(logits=logprob.detach()).sample()
|
||||
sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
|
||||
word = word.detach().long()
|
||||
# sampled_logprob: [N,], word: [N,]
|
||||
return {"word": word, "probs": sampled_logprob}
|
||||
|
||||
def beam_search(self, input_dict):
|
||||
output = self.prepare_output(input_dict)
|
||||
max_length = input_dict["max_length"]
|
||||
beam_size = input_dict["beam_size"]
|
||||
if input_dict["n_best"]:
|
||||
n_best_size = input_dict["n_best_size"]
|
||||
batch_size, max_length = output["seq"].size()
|
||||
output["seq"] = torch.full((batch_size, n_best_size, max_length),
|
||||
self.end_idx, dtype=torch.long)
|
||||
|
||||
temp = input_dict["temp"]
|
||||
# instance by instance beam seach
|
||||
for i in range(output["seq"].size(0)):
|
||||
output_i = self.prepare_beamsearch_output(input_dict)
|
||||
input_dict["sample_idx"] = i
|
||||
for t in range(max_length):
|
||||
input_dict["t"] = t
|
||||
output_t = self.beamsearch_step(input_dict, output_i)
|
||||
#######################################
|
||||
# merge with previous beam and select the current max prob beam
|
||||
#######################################
|
||||
logit_t = output_t["logit"]
|
||||
if logit_t.size(1) == 1:
|
||||
logit_t = logit_t.squeeze(1)
|
||||
elif logit_t.size(1) > 1:
|
||||
logit_t = logit_t[:, -1, :]
|
||||
else:
|
||||
raise Exception("no logit output")
|
||||
logprob_t = torch.log_softmax(logit_t, dim=1)
|
||||
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
|
||||
logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
|
||||
if t == 0: # for the first step, all k seq will have the same probs
|
||||
topk_logprob, topk_words = logprob_t[0].topk(
|
||||
beam_size, 0, True, True)
|
||||
else: # unroll and find top logprob, and their unrolled indices
|
||||
topk_logprob, topk_words = logprob_t.view(-1).topk(
|
||||
beam_size, 0, True, True)
|
||||
topk_words = topk_words.cpu()
|
||||
output_i["topk_logprob"] = topk_logprob
|
||||
# output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
|
||||
output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
|
||||
rounding_mode='trunc')
|
||||
output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
|
||||
if t == 0:
|
||||
output_i["seq"] = output_i["next_word"].unsqueeze(1)
|
||||
else:
|
||||
output_i["seq"] = torch.cat([
|
||||
output_i["seq"][output_i["prev_words_beam"]],
|
||||
output_i["next_word"].unsqueeze(1)], dim=1)
|
||||
|
||||
# add finished beams to results
|
||||
is_end = output_i["next_word"] == self.end_idx
|
||||
if t == max_length - 1:
|
||||
is_end.fill_(1)
|
||||
|
||||
for beam_idx in range(beam_size):
|
||||
if is_end[beam_idx]:
|
||||
final_beam = {
|
||||
"seq": output_i["seq"][beam_idx].clone(),
|
||||
"score": output_i["topk_logprob"][beam_idx].item()
|
||||
}
|
||||
final_beam["score"] = final_beam["score"] / (t + 1)
|
||||
output_i["done_beams"].append(final_beam)
|
||||
output_i["topk_logprob"][is_end] -= 1000
|
||||
|
||||
self.beamsearch_process_step(output_i, output_t)
|
||||
|
||||
self.beamsearch_process(output, output_i, input_dict)
|
||||
return output
|
||||
|
||||
def prepare_beamsearch_output(self, input_dict):
|
||||
beam_size = input_dict["beam_size"]
|
||||
device = input_dict["fc_emb"].device
|
||||
output = {
|
||||
"topk_logprob": torch.zeros(beam_size).to(device),
|
||||
"seq": None,
|
||||
"prev_words_beam": None,
|
||||
"next_word": None,
|
||||
"done_beams": [],
|
||||
}
|
||||
return output
|
||||
|
||||
def beamsearch_step(self, input_dict, output_i):
|
||||
decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
|
||||
output_t = self.decoder(decoder_input)
|
||||
output_t["t"] = input_dict["t"]
|
||||
return output_t
|
||||
|
||||
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
||||
raise NotImplementedError
|
||||
|
||||
def beamsearch_process_step(self, output_i, output_t):
|
||||
pass
|
||||
|
||||
def beamsearch_process(self, output, output_i, input_dict):
|
||||
i = input_dict["sample_idx"]
|
||||
done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
|
||||
if input_dict["n_best"]:
|
||||
done_beams = done_beams[:input_dict["n_best_size"]]
|
||||
for out_idx, done_beam in enumerate(done_beams):
|
||||
seq = done_beam["seq"]
|
||||
output["seq"][i][out_idx, :len(seq)] = seq
|
||||
else:
|
||||
seq = done_beams[0]["seq"]
|
||||
output["seq"][i][:len(seq)] = seq
|
||||
|
||||
def diverse_beam_search(self, input_dict):
|
||||
|
||||
def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
|
||||
local_time = t - divm
|
||||
unaug_logprob = logprob.clone()
|
||||
|
||||
if divm > 0:
|
||||
change = torch.zeros(logprob.size(-1))
|
||||
for prev_choice in range(divm):
|
||||
prev_decisions = seq_table[prev_choice][..., local_time]
|
||||
for prev_labels in range(bdash):
|
||||
change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
|
||||
|
||||
change = change.to(logprob.device)
|
||||
logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
|
||||
|
||||
return logprob, unaug_logprob
|
||||
|
||||
output = self.prepare_output(input_dict)
|
||||
group_size = input_dict["group_size"]
|
||||
batch_size = output["seq"].size(0)
|
||||
beam_size = input_dict["beam_size"]
|
||||
bdash = beam_size // group_size
|
||||
input_dict["bdash"] = bdash
|
||||
diversity_lambda = input_dict["diversity_lambda"]
|
||||
device = input_dict["fc_emb"].device
|
||||
max_length = input_dict["max_length"]
|
||||
temp = input_dict["temp"]
|
||||
group_nbest = input_dict["group_nbest"]
|
||||
batch_size, max_length = output["seq"].size()
|
||||
if group_nbest:
|
||||
output["seq"] = torch.full((batch_size, beam_size, max_length),
|
||||
self.end_idx, dtype=torch.long)
|
||||
else:
|
||||
output["seq"] = torch.full((batch_size, group_size, max_length),
|
||||
self.end_idx, dtype=torch.long)
|
||||
|
||||
|
||||
for i in range(batch_size):
|
||||
input_dict["sample_idx"] = i
|
||||
seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
|
||||
logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
|
||||
done_beams_table = [[] for _ in range(group_size)]
|
||||
|
||||
output_i = {
|
||||
"prev_words_beam": [None for _ in range(group_size)],
|
||||
"next_word": [None for _ in range(group_size)],
|
||||
"state": [None for _ in range(group_size)]
|
||||
}
|
||||
|
||||
for t in range(max_length + group_size - 1):
|
||||
input_dict["t"] = t
|
||||
for divm in range(group_size):
|
||||
input_dict["divm"] = divm
|
||||
if t >= divm and t <= max_length + divm - 1:
|
||||
local_time = t - divm
|
||||
decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
|
||||
output_t = self.decoder(decoder_input)
|
||||
output_t["divm"] = divm
|
||||
logit_t = output_t["logit"]
|
||||
if logit_t.size(1) == 1:
|
||||
logit_t = logit_t.squeeze(1)
|
||||
elif logit_t.size(1) > 1:
|
||||
logit_t = logit_t[:, -1, :]
|
||||
else:
|
||||
raise Exception("no logit output")
|
||||
logprob_t = torch.log_softmax(logit_t, dim=1)
|
||||
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
|
||||
logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
|
||||
logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
|
||||
if local_time == 0: # for the first step, all k seq will have the same probs
|
||||
topk_logprob, topk_words = logprob_t[0].topk(
|
||||
bdash, 0, True, True)
|
||||
else: # unroll and find top logprob, and their unrolled indices
|
||||
topk_logprob, topk_words = logprob_t.view(-1).topk(
|
||||
bdash, 0, True, True)
|
||||
topk_words = topk_words.cpu()
|
||||
logprob_table[divm] = topk_logprob
|
||||
output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
|
||||
output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
|
||||
if local_time > 0:
|
||||
seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
|
||||
seq_table[divm] = torch.cat([
|
||||
seq_table[divm],
|
||||
output_i["next_word"][divm].unsqueeze(-1)], -1)
|
||||
|
||||
is_end = seq_table[divm][:, t-divm] == self.end_idx
|
||||
assert seq_table[divm].shape[-1] == t - divm + 1
|
||||
if t == max_length + divm - 1:
|
||||
is_end.fill_(1)
|
||||
for beam_idx in range(bdash):
|
||||
if is_end[beam_idx]:
|
||||
final_beam = {
|
||||
"seq": seq_table[divm][beam_idx].clone(),
|
||||
"score": logprob_table[divm][beam_idx].item()
|
||||
}
|
||||
final_beam["score"] = final_beam["score"] / (t - divm + 1)
|
||||
done_beams_table[divm].append(final_beam)
|
||||
logprob_table[divm][is_end] -= 1000
|
||||
self.dbs_process_step(output_i, output_t)
|
||||
done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
|
||||
if group_nbest:
|
||||
done_beams = sum(done_beams_table, [])
|
||||
else:
|
||||
done_beams = [group_beam[0] for group_beam in done_beams_table]
|
||||
for _, done_beam in enumerate(done_beams):
|
||||
output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
|
||||
|
||||
return output
|
||||
|
||||
def prepare_dbs_decoder_input(self, input_dict, output_i):
|
||||
raise NotImplementedError
|
||||
|
||||
def dbs_process_step(self, output_i, output_t):
|
||||
pass
|
||||
|
||||
|
||||
class CaptionSequenceModel(nn.Module):
|
||||
|
||||
def __init__(self, model, seq_output_size):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
if model.decoder.d_model != seq_output_size:
|
||||
self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
|
||||
else:
|
||||
self.output_transform = lambda x: x
|
||||
|
||||
def forward(self, input_dict):
|
||||
output = self.model(input_dict)
|
||||
|
||||
if input_dict["mode"] == "train":
|
||||
lens = input_dict["cap_len"] - 1
|
||||
# seq_outputs: [N, d_model]
|
||||
elif input_dict["mode"] == "inference":
|
||||
if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
|
||||
return output
|
||||
seq = output["seq"]
|
||||
lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
|
||||
else:
|
||||
raise Exception("mode should be either 'train' or 'inference'")
|
||||
seq_output = mean_with_lens(output["embed"], lens)
|
||||
seq_output = self.output_transform(seq_output)
|
||||
output["seq_output"] = seq_output
|
||||
return output
|
||||
|
||||
746
audio_to_text/captioning/models/decoder.py
Normal file
746
audio_to_text/captioning/models/decoder.py
Normal file
@@ -0,0 +1,746 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import generate_length_mask, init, PositionalEncoding
|
||||
|
||||
|
||||
class BaseDecoder(nn.Module):
|
||||
"""
|
||||
Take word/audio embeddings and output the next word probs
|
||||
Base decoder, cannot be called directly
|
||||
All decoders should inherit from this class
|
||||
"""
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim,
|
||||
attn_emb_dim, dropout=0.2):
|
||||
super().__init__()
|
||||
self.emb_dim = emb_dim
|
||||
self.vocab_size = vocab_size
|
||||
self.fc_emb_dim = fc_emb_dim
|
||||
self.attn_emb_dim = attn_emb_dim
|
||||
self.word_embedding = nn.Embedding(vocab_size, emb_dim)
|
||||
self.in_dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_word_embedding(self, weight, freeze=True):
|
||||
embedding = np.load(weight)
|
||||
assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
|
||||
assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
|
||||
|
||||
# embeddings = torch.as_tensor(embeddings).float()
|
||||
# self.word_embeddings.weight = nn.Parameter(embeddings)
|
||||
# for para in self.word_embeddings.parameters():
|
||||
# para.requires_grad = tune
|
||||
self.word_embedding = nn.Embedding.from_pretrained(embedding,
|
||||
freeze=freeze)
|
||||
|
||||
|
||||
class RnnDecoder(BaseDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs):
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout,)
|
||||
self.d_model = d_model
|
||||
self.num_layers = kwargs.get('num_layers', 1)
|
||||
self.bidirectional = kwargs.get('bidirectional', False)
|
||||
self.rnn_type = kwargs.get('rnn_type', "GRU")
|
||||
self.classifier = nn.Linear(
|
||||
self.d_model * (self.bidirectional + 1), vocab_size)
|
||||
|
||||
def forward(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_hidden(self, bs, device):
|
||||
num_dire = self.bidirectional + 1
|
||||
n_layer = self.num_layers
|
||||
hid_dim = self.d_model
|
||||
if self.rnn_type == "LSTM":
|
||||
return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
|
||||
torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
|
||||
else:
|
||||
return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)
|
||||
|
||||
|
||||
class RnnFcDecoder(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs):
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim * 2,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None)
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
|
||||
p_fc_emb = self.fc_proj(fc_emb)
|
||||
# embed: [N, T, embed_size]
|
||||
embed = torch.cat((embed, p_fc_emb), dim=-1)
|
||||
|
||||
out, state = self.model(embed, state)
|
||||
# out: [N, T, hs], states: [num_layers * num_dire, N, hs]
|
||||
logits = self.classifier(out)
|
||||
output = {
|
||||
"state": state,
|
||||
"embeds": out,
|
||||
"logits": logits
|
||||
}
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Seq2SeqAttention(nn.Module):
|
||||
|
||||
def __init__(self, hs_enc, hs_dec, attn_size):
|
||||
"""
|
||||
Args:
|
||||
hs_enc: encoder hidden size
|
||||
hs_dec: decoder hidden size
|
||||
attn_size: attention vector size
|
||||
"""
|
||||
super(Seq2SeqAttention, self).__init__()
|
||||
self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
|
||||
self.v = nn.Parameter(torch.randn(attn_size))
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, h_dec, h_enc, src_lens):
|
||||
"""
|
||||
Args:
|
||||
h_dec: decoder hidden (query), [N, hs_dec]
|
||||
h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
|
||||
src_lens: source (encoder memory) lengths, [N, ]
|
||||
"""
|
||||
N = h_enc.size(0)
|
||||
src_max_len = h_enc.size(1)
|
||||
h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
|
||||
|
||||
attn_input = torch.cat((h_dec, h_enc), dim=-1)
|
||||
attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
|
||||
|
||||
v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
|
||||
score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
|
||||
|
||||
idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
|
||||
mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
|
||||
|
||||
score = score.masked_fill(mask == 0, -1e10)
|
||||
weights = torch.softmax(score, dim=-1) # [N, src_max_len]
|
||||
ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
|
||||
|
||||
return ctx, weights
|
||||
|
||||
|
||||
class AttentionProj(nn.Module):
|
||||
|
||||
def __init__(self, hs_enc, hs_dec, embed_dim, attn_size):
|
||||
self.q_proj = nn.Linear(hs_dec, embed_dim)
|
||||
self.kv_proj = nn.Linear(hs_enc, embed_dim)
|
||||
self.h2attn = nn.Linear(embed_dim * 2, attn_size)
|
||||
self.v = nn.Parameter(torch.randn(attn_size))
|
||||
self.apply(init)
|
||||
|
||||
def init(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, h_dec, h_enc, src_lens):
|
||||
"""
|
||||
Args:
|
||||
h_dec: decoder hidden (query), [N, hs_dec]
|
||||
h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
|
||||
src_lens: source (encoder memory) lengths, [N, ]
|
||||
"""
|
||||
h_enc = self.kv_proj(h_enc) # [N, src_max_len, embed_dim]
|
||||
h_dec = self.q_proj(h_dec) # [N, embed_dim]
|
||||
N = h_enc.size(0)
|
||||
src_max_len = h_enc.size(1)
|
||||
h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
|
||||
|
||||
attn_input = torch.cat((h_dec, h_enc), dim=-1)
|
||||
attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
|
||||
|
||||
v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
|
||||
score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
|
||||
|
||||
idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
|
||||
mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
|
||||
|
||||
score = score.masked_fill(mask == 0, -1e10)
|
||||
weights = torch.softmax(score, dim=-1) # [N, src_max_len]
|
||||
ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
|
||||
|
||||
return ctx, weights
|
||||
|
||||
|
||||
class BahAttnDecoder(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs):
|
||||
"""
|
||||
concatenate fc, attn, word to feed to the rnn
|
||||
"""
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs)
|
||||
attn_size = kwargs.get("attn_size", self.d_model)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim * 3,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
||||
self.d_model * (self.bidirectional + 1) * \
|
||||
self.num_layers,
|
||||
attn_size)
|
||||
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
|
||||
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
|
||||
# embed: [N, 1, embed_size]
|
||||
if state is None:
|
||||
state = self.init_hidden(word.size(0), fc_emb.device)
|
||||
if self.rnn_type == "LSTM":
|
||||
query = state[0].transpose(0, 1).flatten(1)
|
||||
else:
|
||||
query = state.transpose(0, 1).flatten(1)
|
||||
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
||||
|
||||
p_fc_emb = self.fc_proj(fc_emb)
|
||||
p_ctx = self.ctx_proj(c)
|
||||
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
|
||||
dim=-1)
|
||||
|
||||
out, state = self.model(rnn_input, state)
|
||||
|
||||
output = {
|
||||
"state": state,
|
||||
"embed": out,
|
||||
"logit": self.classifier(out),
|
||||
"attn_weight": attn_weight
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class BahAttnDecoder2(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs):
|
||||
"""
|
||||
add fc, attn, word together to feed to the rnn
|
||||
"""
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs)
|
||||
attn_size = kwargs.get("attn_size", self.d_model)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.attn = Seq2SeqAttention(self.emb_dim,
|
||||
self.d_model * (self.bidirectional + 1) * \
|
||||
self.num_layers,
|
||||
attn_size)
|
||||
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
|
||||
self.attn_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
||||
self.apply(partial(init, method="xavier"))
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
p_attn_emb = self.attn_proj(attn_emb)
|
||||
|
||||
# embed: [N, 1, embed_size]
|
||||
if state is None:
|
||||
state = self.init_hidden(word.size(0), fc_emb.device)
|
||||
if self.rnn_type == "LSTM":
|
||||
query = state[0].transpose(0, 1).flatten(1)
|
||||
else:
|
||||
query = state.transpose(0, 1).flatten(1)
|
||||
c, attn_weight = self.attn(query, p_attn_emb, attn_emb_len)
|
||||
|
||||
p_fc_emb = self.fc_proj(fc_emb)
|
||||
rnn_input = embed + c.unsqueeze(1) + p_fc_emb.unsqueeze(1)
|
||||
|
||||
out, state = self.model(rnn_input, state)
|
||||
|
||||
output = {
|
||||
"state": state,
|
||||
"embed": out,
|
||||
"logit": self.classifier(out),
|
||||
"attn_weight": attn_weight
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class ConditionalBahAttnDecoder(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs):
|
||||
"""
|
||||
concatenate fc, attn, word to feed to the rnn
|
||||
"""
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs)
|
||||
attn_size = kwargs.get("attn_size", self.d_model)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim * 3,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
||||
self.d_model * (self.bidirectional + 1) * \
|
||||
self.num_layers,
|
||||
attn_size)
|
||||
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
||||
self.condition_embedding = nn.Embedding(2, emb_dim)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
condition = input_dict["condition"]
|
||||
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
|
||||
condition = torch.as_tensor([[1 - c, c] for c in condition]).to(fc_emb.device)
|
||||
condition_emb = torch.matmul(condition, self.condition_embedding.weight)
|
||||
# condition_embs: [N, emb_dim]
|
||||
|
||||
# embed: [N, 1, embed_size]
|
||||
if state is None:
|
||||
state = self.init_hidden(word.size(0), fc_emb.device)
|
||||
if self.rnn_type == "LSTM":
|
||||
query = state[0].transpose(0, 1).flatten(1)
|
||||
else:
|
||||
query = state.transpose(0, 1).flatten(1)
|
||||
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
||||
|
||||
p_ctx = self.ctx_proj(c)
|
||||
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), condition_emb.unsqueeze(1)),
|
||||
dim=-1)
|
||||
|
||||
out, state = self.model(rnn_input, state)
|
||||
|
||||
output = {
|
||||
"state": state,
|
||||
"embed": out,
|
||||
"logit": self.classifier(out),
|
||||
"attn_weight": attn_weight
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class StructBahAttnDecoder(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, struct_vocab_size,
|
||||
attn_emb_dim, dropout, d_model, **kwargs):
|
||||
"""
|
||||
concatenate fc, attn, word to feed to the rnn
|
||||
"""
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs)
|
||||
attn_size = kwargs.get("attn_size", self.d_model)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim * 3,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
||||
self.d_model * (self.bidirectional + 1) * \
|
||||
self.num_layers,
|
||||
attn_size)
|
||||
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
||||
self.struct_embedding = nn.Embedding(struct_vocab_size, emb_dim)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
structure = input_dict["structure"]
|
||||
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
|
||||
struct_emb = self.struct_embedding(structure)
|
||||
# struct_embs: [N, emb_dim]
|
||||
|
||||
# embed: [N, 1, embed_size]
|
||||
if state is None:
|
||||
state = self.init_hidden(word.size(0), fc_emb.device)
|
||||
if self.rnn_type == "LSTM":
|
||||
query = state[0].transpose(0, 1).flatten(1)
|
||||
else:
|
||||
query = state.transpose(0, 1).flatten(1)
|
||||
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
||||
|
||||
p_ctx = self.ctx_proj(c)
|
||||
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), struct_emb.unsqueeze(1)), dim=-1)
|
||||
|
||||
out, state = self.model(rnn_input, state)
|
||||
|
||||
output = {
|
||||
"state": state,
|
||||
"embed": out,
|
||||
"logit": self.classifier(out),
|
||||
"attn_weight": attn_weight
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class StyleBahAttnDecoder(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs):
|
||||
"""
|
||||
concatenate fc, attn, word to feed to the rnn
|
||||
"""
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs)
|
||||
attn_size = kwargs.get("attn_size", self.d_model)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim * 3,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
||||
self.d_model * (self.bidirectional + 1) * \
|
||||
self.num_layers,
|
||||
attn_size)
|
||||
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
style = input_dict["style"]
|
||||
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
|
||||
# embed: [N, 1, embed_size]
|
||||
if state is None:
|
||||
state = self.init_hidden(word.size(0), fc_emb.device)
|
||||
if self.rnn_type == "LSTM":
|
||||
query = state[0].transpose(0, 1).flatten(1)
|
||||
else:
|
||||
query = state.transpose(0, 1).flatten(1)
|
||||
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
||||
|
||||
p_ctx = self.ctx_proj(c)
|
||||
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), style.unsqueeze(1)),
|
||||
dim=-1)
|
||||
|
||||
out, state = self.model(rnn_input, state)
|
||||
|
||||
output = {
|
||||
"state": state,
|
||||
"embed": out,
|
||||
"logit": self.classifier(out),
|
||||
"attn_weight": attn_weight
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class BahAttnDecoder3(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs):
|
||||
"""
|
||||
concatenate fc, attn, word to feed to the rnn
|
||||
"""
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs)
|
||||
attn_size = kwargs.get("attn_size", self.d_model)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim + attn_emb_dim,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
||||
self.d_model * (self.bidirectional + 1) * \
|
||||
self.num_layers,
|
||||
attn_size)
|
||||
self.ctx_proj = lambda x: x
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
|
||||
if word.size(-1) == self.fc_emb_dim: # fc_emb
|
||||
embed = word.unsqueeze(1)
|
||||
elif word.size(-1) == 1: # word
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
else:
|
||||
raise Exception(f"problem with word input size {word.size()}")
|
||||
|
||||
# embed: [N, 1, embed_size]
|
||||
if state is None:
|
||||
state = self.init_hidden(word.size(0), fc_emb.device)
|
||||
if self.rnn_type == "LSTM":
|
||||
query = state[0].transpose(0, 1).flatten(1)
|
||||
else:
|
||||
query = state.transpose(0, 1).flatten(1)
|
||||
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
||||
|
||||
p_ctx = self.ctx_proj(c)
|
||||
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1)), dim=-1)
|
||||
|
||||
out, state = self.model(rnn_input, state)
|
||||
|
||||
output = {
|
||||
"state": state,
|
||||
"embed": out,
|
||||
"logit": self.classifier(out),
|
||||
"attn_weight": attn_weight
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class SpecificityBahAttnDecoder(RnnDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs):
|
||||
"""
|
||||
concatenate fc, attn, word to feed to the rnn
|
||||
"""
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, d_model, **kwargs)
|
||||
attn_size = kwargs.get("attn_size", self.d_model)
|
||||
self.model = getattr(nn, self.rnn_type)(
|
||||
input_size=self.emb_dim + attn_emb_dim + 1,
|
||||
hidden_size=self.d_model,
|
||||
batch_first=True,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional)
|
||||
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
||||
self.d_model * (self.bidirectional + 1) * \
|
||||
self.num_layers,
|
||||
attn_size)
|
||||
self.ctx_proj = lambda x: x
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
||||
fc_emb = input_dict["fc_emb"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
condition = input_dict["condition"] # [N,]
|
||||
|
||||
word = word.to(fc_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word))
|
||||
|
||||
# embed: [N, 1, embed_size]
|
||||
if state is None:
|
||||
state = self.init_hidden(word.size(0), fc_emb.device)
|
||||
if self.rnn_type == "LSTM":
|
||||
query = state[0].transpose(0, 1).flatten(1)
|
||||
else:
|
||||
query = state.transpose(0, 1).flatten(1)
|
||||
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
||||
|
||||
p_ctx = self.ctx_proj(c)
|
||||
rnn_input = torch.cat(
|
||||
(embed, p_ctx.unsqueeze(1), condition.reshape(-1, 1, 1)),
|
||||
dim=-1)
|
||||
|
||||
out, state = self.model(rnn_input, state)
|
||||
|
||||
output = {
|
||||
"state": state,
|
||||
"embed": out,
|
||||
"logit": self.classifier(out),
|
||||
"attn_weight": attn_weight
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(BaseDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, **kwargs):
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout=dropout,)
|
||||
self.d_model = emb_dim
|
||||
self.nhead = kwargs.get("nhead", self.d_model // 64)
|
||||
self.nlayers = kwargs.get("nlayers", 2)
|
||||
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
|
||||
|
||||
self.pos_encoder = PositionalEncoding(self.d_model, dropout)
|
||||
layer = nn.TransformerDecoderLayer(d_model=self.d_model,
|
||||
nhead=self.nhead,
|
||||
dim_feedforward=self.dim_feedforward,
|
||||
dropout=dropout)
|
||||
self.model = nn.TransformerDecoder(layer, self.nlayers)
|
||||
self.classifier = nn.Linear(self.d_model, vocab_size)
|
||||
self.attn_proj = nn.Sequential(
|
||||
nn.Linear(self.attn_emb_dim, self.d_model),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.LayerNorm(self.d_model)
|
||||
)
|
||||
# self.attn_proj = lambda x: x
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def generate_square_subsequent_mask(self, max_length):
|
||||
mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
return mask
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"]
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
cap_padding_mask = input_dict["cap_padding_mask"]
|
||||
|
||||
p_attn_emb = self.attn_proj(attn_emb)
|
||||
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
|
||||
word = word.to(attn_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
|
||||
embed = embed.transpose(0, 1) # [T, N, emb_dim]
|
||||
embed = self.pos_encoder(embed)
|
||||
|
||||
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
|
||||
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
|
||||
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=cap_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask)
|
||||
output = output.transpose(0, 1)
|
||||
output = {
|
||||
"embed": output,
|
||||
"logit": self.classifier(output),
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
||||
class EventTransformerDecoder(TransformerDecoder):
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"] # index of word embeddings
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
cap_padding_mask = input_dict["cap_padding_mask"]
|
||||
event_emb = input_dict["event"] # [N, emb_dim]
|
||||
|
||||
p_attn_emb = self.attn_proj(attn_emb)
|
||||
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
|
||||
word = word.to(attn_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
|
||||
|
||||
embed = embed.transpose(0, 1) # [T, N, emb_dim]
|
||||
embed += event_emb
|
||||
embed = self.pos_encoder(embed)
|
||||
|
||||
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
|
||||
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
|
||||
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=cap_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask)
|
||||
output = output.transpose(0, 1)
|
||||
output = {
|
||||
"embed": output,
|
||||
"logit": self.classifier(output),
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class KeywordProbTransformerDecoder(TransformerDecoder):
|
||||
|
||||
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, keyword_classes_num, **kwargs):
|
||||
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
||||
dropout, **kwargs)
|
||||
self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
|
||||
self.word_keyword_norm = nn.LayerNorm(self.d_model)
|
||||
|
||||
def forward(self, input_dict):
|
||||
word = input_dict["word"] # index of word embeddings
|
||||
attn_emb = input_dict["attn_emb"]
|
||||
attn_emb_len = input_dict["attn_emb_len"]
|
||||
cap_padding_mask = input_dict["cap_padding_mask"]
|
||||
keyword = input_dict["keyword"] # [N, keyword_classes_num]
|
||||
|
||||
p_attn_emb = self.attn_proj(attn_emb)
|
||||
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
|
||||
word = word.to(attn_emb.device)
|
||||
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
|
||||
|
||||
embed = embed.transpose(0, 1) # [T, N, emb_dim]
|
||||
embed += self.keyword_proj(keyword)
|
||||
embed = self.word_keyword_norm(embed)
|
||||
|
||||
embed = self.pos_encoder(embed)
|
||||
|
||||
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
|
||||
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
|
||||
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=cap_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask)
|
||||
output = output.transpose(0, 1)
|
||||
output = {
|
||||
"embed": output,
|
||||
"logit": self.classifier(output),
|
||||
}
|
||||
return output
|
||||
686
audio_to_text/captioning/models/encoder.py
Normal file
686
audio_to_text/captioning/models/encoder.py
Normal file
@@ -0,0 +1,686 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchaudio import transforms
|
||||
from torchlibrosa.augmentation import SpecAugmentation
|
||||
|
||||
from .utils import mean_with_lens, max_with_lens, \
|
||||
init, pack_wrapper, generate_length_mask, PositionalEncoding
|
||||
|
||||
|
||||
def init_layer(layer):
|
||||
"""Initialize a Linear or Convolutional layer. """
|
||||
nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
if hasattr(layer, 'bias'):
|
||||
if layer.bias is not None:
|
||||
layer.bias.data.fill_(0.)
|
||||
|
||||
|
||||
def init_bn(bn):
|
||||
"""Initialize a Batchnorm layer. """
|
||||
bn.bias.data.fill_(0.)
|
||||
bn.weight.data.fill_(1.)
|
||||
|
||||
|
||||
class BaseEncoder(nn.Module):
|
||||
|
||||
"""
|
||||
Encode the given audio into embedding
|
||||
Base encoder class, cannot be called directly
|
||||
All encoders should inherit from this class
|
||||
"""
|
||||
|
||||
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
|
||||
super(BaseEncoder, self).__init__()
|
||||
self.spec_dim = spec_dim
|
||||
self.fc_feat_dim = fc_feat_dim
|
||||
self.attn_feat_dim = attn_feat_dim
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
#########################
|
||||
# an encoder first encodes audio feature into embedding, obtaining
|
||||
# `encoded`: {
|
||||
# fc_embs: [N, fc_emb_dim],
|
||||
# attn_embs: [N, attn_max_len, attn_emb_dim],
|
||||
# attn_emb_lens: [N,]
|
||||
# }
|
||||
#########################
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Block2D(nn.Module):
|
||||
|
||||
def __init__(self, cin, cout, kernel_size=3, padding=1):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.BatchNorm2d(cin),
|
||||
nn.Conv2d(cin,
|
||||
cout,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
bias=False),
|
||||
nn.LeakyReLU(inplace=True, negative_slope=0.1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class LinearSoftPool(nn.Module):
|
||||
"""LinearSoftPool
|
||||
Linear softmax, takes logits and returns a probability, near to the actual maximum value.
|
||||
Taken from the paper:
|
||||
A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling
|
||||
https://arxiv.org/abs/1810.09050
|
||||
"""
|
||||
def __init__(self, pooldim=1):
|
||||
super().__init__()
|
||||
self.pooldim = pooldim
|
||||
|
||||
def forward(self, logits, time_decision):
|
||||
return (time_decision**2).sum(self.pooldim) / time_decision.sum(
|
||||
self.pooldim)
|
||||
|
||||
|
||||
class MeanPool(nn.Module):
|
||||
|
||||
def __init__(self, pooldim=1):
|
||||
super().__init__()
|
||||
self.pooldim = pooldim
|
||||
|
||||
def forward(self, logits, decision):
|
||||
return torch.mean(decision, dim=self.pooldim)
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
|
||||
"""docstring for AttentionPool"""
|
||||
def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
|
||||
super().__init__()
|
||||
self.inputdim = inputdim
|
||||
self.outputdim = outputdim
|
||||
self.pooldim = pooldim
|
||||
self.transform = nn.Linear(inputdim, outputdim)
|
||||
self.activ = nn.Softmax(dim=self.pooldim)
|
||||
self.eps = 1e-7
|
||||
|
||||
def forward(self, logits, decision):
|
||||
# Input is (B, T, D)
|
||||
# B, T, D
|
||||
w = self.activ(torch.clamp(self.transform(logits), -15, 15))
|
||||
detect = (decision * w).sum(
|
||||
self.pooldim) / (w.sum(self.pooldim) + self.eps)
|
||||
# B, T, D
|
||||
return detect
|
||||
|
||||
|
||||
class MMPool(nn.Module):
|
||||
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
self.avgpool = nn.AvgPool2d(dims)
|
||||
self.maxpool = nn.MaxPool2d(dims)
|
||||
|
||||
def forward(self, x):
|
||||
return self.avgpool(x) + self.maxpool(x)
|
||||
|
||||
|
||||
def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
|
||||
"""parse_poolingfunction
|
||||
A heler function to parse any temporal pooling
|
||||
Pooling is done on dimension 1
|
||||
:param poolingfunction_name:
|
||||
:param **kwargs:
|
||||
"""
|
||||
poolingfunction_name = poolingfunction_name.lower()
|
||||
if poolingfunction_name == 'mean':
|
||||
return MeanPool(pooldim=1)
|
||||
elif poolingfunction_name == 'linear':
|
||||
return LinearSoftPool(pooldim=1)
|
||||
elif poolingfunction_name == 'attention':
|
||||
return AttentionPool(inputdim=kwargs['inputdim'],
|
||||
outputdim=kwargs['outputdim'])
|
||||
|
||||
|
||||
def embedding_pooling(x, lens, pooling="mean"):
|
||||
if pooling == "max":
|
||||
fc_embs = max_with_lens(x, lens)
|
||||
elif pooling == "mean":
|
||||
fc_embs = mean_with_lens(x, lens)
|
||||
elif pooling == "mean+max":
|
||||
x_mean = mean_with_lens(x, lens)
|
||||
x_max = max_with_lens(x, lens)
|
||||
fc_embs = x_mean + x_max
|
||||
elif pooling == "last":
|
||||
indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
|
||||
# indices: [N, 1, hidden]
|
||||
fc_embs = torch.gather(x, 1, indices).squeeze(1)
|
||||
else:
|
||||
raise Exception(f"pooling method {pooling} not support")
|
||||
return fc_embs
|
||||
|
||||
|
||||
class Cdur5Encoder(BaseEncoder):
|
||||
|
||||
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
|
||||
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
||||
self.pooling = pooling
|
||||
self.features = nn.Sequential(
|
||||
Block2D(1, 32),
|
||||
nn.LPPool2d(4, (2, 4)),
|
||||
Block2D(32, 128),
|
||||
Block2D(128, 128),
|
||||
nn.LPPool2d(4, (2, 4)),
|
||||
Block2D(128, 128),
|
||||
Block2D(128, 128),
|
||||
nn.LPPool2d(4, (1, 4)),
|
||||
nn.Dropout(0.3),
|
||||
)
|
||||
with torch.no_grad():
|
||||
rnn_input_dim = self.features(
|
||||
torch.randn(1, 1, 500, spec_dim)).shape
|
||||
rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
|
||||
|
||||
self.gru = nn.GRU(rnn_input_dim,
|
||||
128,
|
||||
bidirectional=True,
|
||||
batch_first=True)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
x = input_dict["spec"]
|
||||
lens = input_dict["spec_len"]
|
||||
if "upsample" not in input_dict:
|
||||
input_dict["upsample"] = False
|
||||
lens = torch.as_tensor(copy.deepcopy(lens))
|
||||
N, T, _ = x.shape
|
||||
x = x.unsqueeze(1)
|
||||
x = self.features(x)
|
||||
x = x.transpose(1, 2).contiguous().flatten(-2)
|
||||
x, _ = self.gru(x)
|
||||
if input_dict["upsample"]:
|
||||
x = nn.functional.interpolate(
|
||||
x.transpose(1, 2),
|
||||
T,
|
||||
mode='linear',
|
||||
align_corners=False).transpose(1, 2)
|
||||
else:
|
||||
lens //= 4
|
||||
attn_emb = x
|
||||
fc_emb = embedding_pooling(x, lens, self.pooling)
|
||||
return {
|
||||
"attn_emb": attn_emb,
|
||||
"fc_emb": fc_emb,
|
||||
"attn_emb_len": lens
|
||||
}
|
||||
|
||||
|
||||
def conv_conv_block(in_channel, out_channel):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_channel,
|
||||
out_channel,
|
||||
kernel_size=3,
|
||||
bias=False,
|
||||
padding=1),
|
||||
nn.BatchNorm2d(out_channel),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(out_channel,
|
||||
out_channel,
|
||||
kernel_size=3,
|
||||
bias=False,
|
||||
padding=1),
|
||||
nn.BatchNorm2d(out_channel),
|
||||
nn.ReLU(True)
|
||||
)
|
||||
|
||||
|
||||
class Cdur8Encoder(BaseEncoder):
|
||||
|
||||
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
|
||||
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
||||
self.pooling = pooling
|
||||
self.features = nn.Sequential(
|
||||
conv_conv_block(1, 64),
|
||||
MMPool((2, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
conv_conv_block(64, 128),
|
||||
MMPool((2, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
conv_conv_block(128, 256),
|
||||
MMPool((1, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
conv_conv_block(256, 512),
|
||||
MMPool((1, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
nn.AdaptiveAvgPool2d((None, 1)),
|
||||
)
|
||||
self.init_bn = nn.BatchNorm2d(spec_dim)
|
||||
self.embedding = nn.Linear(512, 512)
|
||||
self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
x = input_dict["spec"]
|
||||
lens = input_dict["spec_len"]
|
||||
lens = torch.as_tensor(copy.deepcopy(lens))
|
||||
x = x.unsqueeze(1) # B x 1 x T x D
|
||||
x = x.transpose(1, 3)
|
||||
x = self.init_bn(x)
|
||||
x = x.transpose(1, 3)
|
||||
x = self.features(x)
|
||||
x = x.transpose(1, 2).contiguous().flatten(-2)
|
||||
x = F.dropout(x, p=0.5, training=self.training)
|
||||
x = F.relu_(self.embedding(x))
|
||||
x, _ = self.gru(x)
|
||||
attn_emb = x
|
||||
lens //= 4
|
||||
fc_emb = embedding_pooling(x, lens, self.pooling)
|
||||
return {
|
||||
"attn_emb": attn_emb,
|
||||
"fc_emb": fc_emb,
|
||||
"attn_emb_len": lens
|
||||
}
|
||||
|
||||
|
||||
class Cnn10Encoder(BaseEncoder):
|
||||
|
||||
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
|
||||
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
||||
self.features = nn.Sequential(
|
||||
conv_conv_block(1, 64),
|
||||
nn.AvgPool2d((2, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
conv_conv_block(64, 128),
|
||||
nn.AvgPool2d((2, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
conv_conv_block(128, 256),
|
||||
nn.AvgPool2d((2, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
conv_conv_block(256, 512),
|
||||
nn.AvgPool2d((2, 2)),
|
||||
nn.Dropout(0.2, True),
|
||||
nn.AdaptiveAvgPool2d((None, 1)),
|
||||
)
|
||||
self.init_bn = nn.BatchNorm2d(spec_dim)
|
||||
self.embedding = nn.Linear(512, 512)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
x = input_dict["spec"]
|
||||
lens = input_dict["spec_len"]
|
||||
lens = torch.as_tensor(copy.deepcopy(lens))
|
||||
x = x.unsqueeze(1) # [N, 1, T, D]
|
||||
x = x.transpose(1, 3)
|
||||
x = self.init_bn(x)
|
||||
x = x.transpose(1, 3)
|
||||
x = self.features(x) # [N, 512, T/16, 1]
|
||||
x = x.transpose(1, 2).contiguous().flatten(-2) # [N, T/16, 512]
|
||||
attn_emb = x
|
||||
lens //= 16
|
||||
fc_emb = embedding_pooling(x, lens, "mean+max")
|
||||
fc_emb = F.dropout(fc_emb, p=0.5, training=self.training)
|
||||
fc_emb = self.embedding(fc_emb)
|
||||
fc_emb = F.relu_(fc_emb)
|
||||
return {
|
||||
"attn_emb": attn_emb,
|
||||
"fc_emb": fc_emb,
|
||||
"attn_emb_len": lens
|
||||
}
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
|
||||
super(ConvBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3), stride=(1, 1),
|
||||
padding=(1, 1), bias=False)
|
||||
|
||||
self.conv2 = nn.Conv2d(in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3), stride=(1, 1),
|
||||
padding=(1, 1), bias=False)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
self.init_weight()
|
||||
|
||||
def init_weight(self):
|
||||
init_layer(self.conv1)
|
||||
init_layer(self.conv2)
|
||||
init_bn(self.bn1)
|
||||
init_bn(self.bn2)
|
||||
|
||||
|
||||
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
|
||||
|
||||
x = input
|
||||
x = F.relu_(self.bn1(self.conv1(x)))
|
||||
x = F.relu_(self.bn2(self.conv2(x)))
|
||||
if pool_type == 'max':
|
||||
x = F.max_pool2d(x, kernel_size=pool_size)
|
||||
elif pool_type == 'avg':
|
||||
x = F.avg_pool2d(x, kernel_size=pool_size)
|
||||
elif pool_type == 'avg+max':
|
||||
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
||||
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
||||
x = x1 + x2
|
||||
else:
|
||||
raise Exception('Incorrect argument!')
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Cnn14Encoder(nn.Module):
|
||||
def __init__(self, sample_rate=32000):
|
||||
super().__init__()
|
||||
sr_to_fmax = {
|
||||
32000: 14000,
|
||||
16000: 8000
|
||||
}
|
||||
# Logmel spectrogram extractor
|
||||
self.melspec_extractor = transforms.MelSpectrogram(
|
||||
sample_rate=sample_rate,
|
||||
n_fft=32 * sample_rate // 1000,
|
||||
win_length=32 * sample_rate // 1000,
|
||||
hop_length=10 * sample_rate // 1000,
|
||||
f_min=50,
|
||||
f_max=sr_to_fmax[sample_rate],
|
||||
n_mels=64,
|
||||
norm="slaney",
|
||||
mel_scale="slaney"
|
||||
)
|
||||
self.hop_length = 10 * sample_rate // 1000
|
||||
self.db_transform = transforms.AmplitudeToDB()
|
||||
# Spec augmenter
|
||||
self.spec_augmenter = SpecAugmentation(time_drop_width=64,
|
||||
time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2)
|
||||
|
||||
self.bn0 = nn.BatchNorm2d(64)
|
||||
|
||||
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
|
||||
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
|
||||
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
|
||||
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
|
||||
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
|
||||
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
|
||||
|
||||
self.downsample_ratio = 32
|
||||
|
||||
self.fc1 = nn.Linear(2048, 2048, bias=True)
|
||||
|
||||
self.init_weight()
|
||||
|
||||
def init_weight(self):
|
||||
init_bn(self.bn0)
|
||||
init_layer(self.fc1)
|
||||
|
||||
def load_pretrained(self, pretrained):
|
||||
checkpoint = torch.load(pretrained, map_location="cpu")
|
||||
|
||||
if "model" in checkpoint:
|
||||
state_keys = checkpoint["model"].keys()
|
||||
backbone = False
|
||||
for key in state_keys:
|
||||
if key.startswith("backbone."):
|
||||
backbone = True
|
||||
break
|
||||
|
||||
if backbone: # COLA
|
||||
state_dict = {}
|
||||
for key, value in checkpoint["model"].items():
|
||||
if key.startswith("backbone."):
|
||||
model_key = key.replace("backbone.", "")
|
||||
state_dict[model_key] = value
|
||||
else: # PANNs
|
||||
state_dict = checkpoint["model"]
|
||||
elif "state_dict" in checkpoint: # CLAP
|
||||
state_dict = checkpoint["state_dict"]
|
||||
state_dict_keys = list(filter(
|
||||
lambda x: "audio_encoder" in x, state_dict.keys()))
|
||||
state_dict = {
|
||||
key.replace('audio_encoder.', ''): state_dict[key]
|
||||
for key in state_dict_keys
|
||||
}
|
||||
else:
|
||||
raise Exception("Unkown checkpoint format")
|
||||
|
||||
model_dict = self.state_dict()
|
||||
pretrained_dict = {
|
||||
k: v for k, v in state_dict.items() if (k in model_dict) and (
|
||||
model_dict[k].shape == v.shape)
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.load_state_dict(model_dict, strict=True)
|
||||
|
||||
def forward(self, input_dict):
|
||||
"""
|
||||
Input: (batch_size, n_samples)"""
|
||||
waveform = input_dict["wav"]
|
||||
wave_length = input_dict["wav_len"]
|
||||
specaug = input_dict["specaug"]
|
||||
x = self.melspec_extractor(waveform)
|
||||
x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
|
||||
x = x.transpose(1, 2)
|
||||
x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
|
||||
|
||||
# SpecAugment
|
||||
if self.training and specaug:
|
||||
x = self.spec_augmenter(x)
|
||||
|
||||
x = x.transpose(1, 3)
|
||||
x = self.bn0(x)
|
||||
x = x.transpose(1, 3)
|
||||
|
||||
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
|
||||
x = F.dropout(x, p=0.2, training=self.training)
|
||||
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
|
||||
x = F.dropout(x, p=0.2, training=self.training)
|
||||
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
|
||||
x = F.dropout(x, p=0.2, training=self.training)
|
||||
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
|
||||
x = F.dropout(x, p=0.2, training=self.training)
|
||||
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
|
||||
x = F.dropout(x, p=0.2, training=self.training)
|
||||
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
|
||||
x = F.dropout(x, p=0.2, training=self.training)
|
||||
x = torch.mean(x, dim=3)
|
||||
attn_emb = x.transpose(1, 2)
|
||||
|
||||
wave_length = torch.as_tensor(wave_length)
|
||||
feat_length = torch.div(wave_length, self.hop_length,
|
||||
rounding_mode="floor") + 1
|
||||
feat_length = torch.div(feat_length, self.downsample_ratio,
|
||||
rounding_mode="floor")
|
||||
x_max = max_with_lens(attn_emb, feat_length)
|
||||
x_mean = mean_with_lens(attn_emb, feat_length)
|
||||
x = x_max + x_mean
|
||||
x = F.dropout(x, p=0.5, training=self.training)
|
||||
x = F.relu_(self.fc1(x))
|
||||
fc_emb = F.dropout(x, p=0.5, training=self.training)
|
||||
|
||||
output_dict = {
|
||||
'fc_emb': fc_emb,
|
||||
'attn_emb': attn_emb,
|
||||
'attn_emb_len': feat_length
|
||||
}
|
||||
|
||||
return output_dict
|
||||
|
||||
|
||||
class RnnEncoder(BaseEncoder):
|
||||
|
||||
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim,
|
||||
pooling="mean", **kwargs):
|
||||
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
||||
self.pooling = pooling
|
||||
self.hidden_size = kwargs.get('hidden_size', 512)
|
||||
self.bidirectional = kwargs.get('bidirectional', False)
|
||||
self.num_layers = kwargs.get('num_layers', 1)
|
||||
self.dropout = kwargs.get('dropout', 0.2)
|
||||
self.rnn_type = kwargs.get('rnn_type', "GRU")
|
||||
self.in_bn = kwargs.get('in_bn', False)
|
||||
self.embed_dim = self.hidden_size * (self.bidirectional + 1)
|
||||
self.network = getattr(nn, self.rnn_type)(
|
||||
attn_feat_dim,
|
||||
self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=self.bidirectional,
|
||||
dropout=self.dropout,
|
||||
batch_first=True)
|
||||
if self.in_bn:
|
||||
self.bn = nn.BatchNorm1d(self.embed_dim)
|
||||
self.apply(init)
|
||||
|
||||
def forward(self, input_dict):
|
||||
x = input_dict["attn"]
|
||||
lens = input_dict["attn_len"]
|
||||
lens = torch.as_tensor(lens)
|
||||
# x: [N, T, E]
|
||||
if self.in_bn:
|
||||
x = pack_wrapper(self.bn, x, lens)
|
||||
out = pack_wrapper(self.network, x, lens)
|
||||
# out: [N, T, hidden]
|
||||
attn_emb = out
|
||||
fc_emb = embedding_pooling(out, lens, self.pooling)
|
||||
return {
|
||||
"attn_emb": attn_emb,
|
||||
"fc_emb": fc_emb,
|
||||
"attn_emb_len": lens
|
||||
}
|
||||
|
||||
|
||||
class Cnn14RnnEncoder(nn.Module):
|
||||
def __init__(self, sample_rate=32000, pretrained=None,
|
||||
freeze_cnn=False, freeze_cnn_bn=False,
|
||||
pooling="mean", **kwargs):
|
||||
super().__init__()
|
||||
self.cnn = Cnn14Encoder(sample_rate)
|
||||
self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs)
|
||||
if pretrained is not None:
|
||||
self.cnn.load_pretrained(pretrained)
|
||||
if freeze_cnn:
|
||||
assert pretrained is not None, "cnn is not pretrained but frozen"
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.freeze_cnn_bn = freeze_cnn_bn
|
||||
|
||||
def train(self, mode):
|
||||
super().train(mode=mode)
|
||||
if self.freeze_cnn_bn:
|
||||
def bn_eval(module):
|
||||
class_name = module.__class__.__name__
|
||||
if class_name.find("BatchNorm") != -1:
|
||||
module.eval()
|
||||
self.cnn.apply(bn_eval)
|
||||
return self
|
||||
|
||||
def forward(self, input_dict):
|
||||
output_dict = self.cnn(input_dict)
|
||||
output_dict["attn"] = output_dict["attn_emb"]
|
||||
output_dict["attn_len"] = output_dict["attn_emb_len"]
|
||||
del output_dict["attn_emb"], output_dict["attn_emb_len"]
|
||||
output_dict = self.rnn(output_dict)
|
||||
return output_dict
|
||||
|
||||
|
||||
class TransformerEncoder(BaseEncoder):
|
||||
|
||||
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs):
|
||||
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
||||
self.d_model = d_model
|
||||
dropout = kwargs.get("dropout", 0.2)
|
||||
self.nhead = kwargs.get("nhead", self.d_model // 64)
|
||||
self.nlayers = kwargs.get("nlayers", 2)
|
||||
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
|
||||
|
||||
self.attn_proj = nn.Sequential(
|
||||
nn.Linear(attn_feat_dim, self.d_model),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.LayerNorm(self.d_model)
|
||||
)
|
||||
layer = nn.TransformerEncoderLayer(d_model=self.d_model,
|
||||
nhead=self.nhead,
|
||||
dim_feedforward=self.dim_feedforward,
|
||||
dropout=dropout)
|
||||
self.model = nn.TransformerEncoder(layer, self.nlayers)
|
||||
self.cls_token = nn.Parameter(torch.zeros(d_model))
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, input_dict):
|
||||
attn_feat = input_dict["attn"]
|
||||
attn_feat_len = input_dict["attn_len"]
|
||||
attn_feat_len = torch.as_tensor(attn_feat_len)
|
||||
|
||||
attn_feat = self.attn_proj(attn_feat) # [bs, T, d_model]
|
||||
|
||||
cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat(
|
||||
attn_feat.size(0), 1, 1)
|
||||
attn_feat = torch.cat((cls_emb, attn_feat), dim=1)
|
||||
attn_feat = attn_feat.transpose(0, 1)
|
||||
|
||||
attn_feat_len += 1
|
||||
src_key_padding_mask = ~generate_length_mask(
|
||||
attn_feat_len, attn_feat.size(0)).to(attn_feat.device)
|
||||
output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask)
|
||||
|
||||
attn_emb = output.transpose(0, 1)
|
||||
fc_emb = attn_emb[:, 0]
|
||||
return {
|
||||
"attn_emb": attn_emb,
|
||||
"fc_emb": fc_emb,
|
||||
"attn_emb_len": attn_feat_len
|
||||
}
|
||||
|
||||
|
||||
class Cnn14TransformerEncoder(nn.Module):
|
||||
def __init__(self, sample_rate=32000, pretrained=None,
|
||||
freeze_cnn=False, freeze_cnn_bn=False,
|
||||
d_model="mean", **kwargs):
|
||||
super().__init__()
|
||||
self.cnn = Cnn14Encoder(sample_rate)
|
||||
self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs)
|
||||
if pretrained is not None:
|
||||
self.cnn.load_pretrained(pretrained)
|
||||
if freeze_cnn:
|
||||
assert pretrained is not None, "cnn is not pretrained but frozen"
|
||||
for param in self.cnn.parameters():
|
||||
param.requires_grad = False
|
||||
self.freeze_cnn_bn = freeze_cnn_bn
|
||||
|
||||
def train(self, mode):
|
||||
super().train(mode=mode)
|
||||
if self.freeze_cnn_bn:
|
||||
def bn_eval(module):
|
||||
class_name = module.__class__.__name__
|
||||
if class_name.find("BatchNorm") != -1:
|
||||
module.eval()
|
||||
self.cnn.apply(bn_eval)
|
||||
return self
|
||||
|
||||
def forward(self, input_dict):
|
||||
output_dict = self.cnn(input_dict)
|
||||
output_dict["attn"] = output_dict["attn_emb"]
|
||||
output_dict["attn_len"] = output_dict["attn_emb_len"]
|
||||
del output_dict["attn_emb"], output_dict["attn_emb_len"]
|
||||
output_dict = self.trm(output_dict)
|
||||
return output_dict
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
265
audio_to_text/captioning/models/transformer_model.py
Normal file
265
audio_to_text/captioning/models/transformer_model.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import CaptionModel
|
||||
from .utils import repeat_tensor
|
||||
import audio_to_text.captioning.models.decoder
|
||||
|
||||
|
||||
class TransformerModel(CaptionModel):
|
||||
|
||||
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
||||
if not hasattr(self, "compatible_decoders"):
|
||||
self.compatible_decoders = (
|
||||
audio_to_text.captioning.models.decoder.TransformerDecoder,
|
||||
)
|
||||
super().__init__(encoder, decoder, **kwargs)
|
||||
|
||||
def seq_forward(self, input_dict):
|
||||
cap = input_dict["cap"]
|
||||
cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
||||
cap_padding_mask = cap_padding_mask[:, :-1]
|
||||
output = self.decoder(
|
||||
{
|
||||
"word": cap[:, :-1],
|
||||
"attn_emb": input_dict["attn_emb"],
|
||||
"attn_emb_len": input_dict["attn_emb_len"],
|
||||
"cap_padding_mask": cap_padding_mask
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
def prepare_decoder_input(self, input_dict, output):
|
||||
decoder_input = {
|
||||
"attn_emb": input_dict["attn_emb"],
|
||||
"attn_emb_len": input_dict["attn_emb_len"]
|
||||
}
|
||||
t = input_dict["t"]
|
||||
|
||||
###############
|
||||
# determine input word
|
||||
################
|
||||
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
|
||||
word = input_dict["cap"][:, :t+1]
|
||||
else:
|
||||
start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
|
||||
if t == 0:
|
||||
word = start_word
|
||||
else:
|
||||
word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
|
||||
# word: [N, T]
|
||||
decoder_input["word"] = word
|
||||
|
||||
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
|
||||
decoder_input["cap_padding_mask"] = cap_padding_mask
|
||||
return decoder_input
|
||||
|
||||
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
||||
decoder_input = {}
|
||||
t = input_dict["t"]
|
||||
i = input_dict["sample_idx"]
|
||||
beam_size = input_dict["beam_size"]
|
||||
###############
|
||||
# prepare attn embeds
|
||||
################
|
||||
if t == 0:
|
||||
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
|
||||
attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
|
||||
output_i["attn_emb"] = attn_emb
|
||||
output_i["attn_emb_len"] = attn_emb_len
|
||||
decoder_input["attn_emb"] = output_i["attn_emb"]
|
||||
decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
|
||||
###############
|
||||
# determine input word
|
||||
################
|
||||
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
|
||||
if t == 0:
|
||||
word = start_word
|
||||
else:
|
||||
word = torch.cat((start_word, output_i["seq"]), dim=-1)
|
||||
decoder_input["word"] = word
|
||||
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
|
||||
decoder_input["cap_padding_mask"] = cap_padding_mask
|
||||
|
||||
return decoder_input
|
||||
|
||||
|
||||
class M2TransformerModel(CaptionModel):
|
||||
|
||||
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
||||
if not hasattr(self, "compatible_decoders"):
|
||||
self.compatible_decoders = (
|
||||
captioning.models.decoder.M2TransformerDecoder,
|
||||
)
|
||||
super().__init__(encoder, decoder, **kwargs)
|
||||
self.check_encoder_compatibility()
|
||||
|
||||
def check_encoder_compatibility(self):
|
||||
assert isinstance(self.encoder, captioning.models.encoder.M2TransformerEncoder), \
|
||||
f"only M2TransformerModel is compatible with {self.__class__.__name__}"
|
||||
|
||||
|
||||
def seq_forward(self, input_dict):
|
||||
cap = input_dict["cap"]
|
||||
output = self.decoder(
|
||||
{
|
||||
"word": cap[:, :-1],
|
||||
"attn_emb": input_dict["attn_emb"],
|
||||
"attn_emb_mask": input_dict["attn_emb_mask"],
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
def prepare_decoder_input(self, input_dict, output):
|
||||
decoder_input = {
|
||||
"attn_emb": input_dict["attn_emb"],
|
||||
"attn_emb_mask": input_dict["attn_emb_mask"]
|
||||
}
|
||||
t = input_dict["t"]
|
||||
|
||||
###############
|
||||
# determine input word
|
||||
################
|
||||
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
|
||||
word = input_dict["cap"][:, :t+1]
|
||||
else:
|
||||
start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
|
||||
if t == 0:
|
||||
word = start_word
|
||||
else:
|
||||
word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
|
||||
# word: [N, T]
|
||||
decoder_input["word"] = word
|
||||
|
||||
return decoder_input
|
||||
|
||||
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
||||
decoder_input = {}
|
||||
t = input_dict["t"]
|
||||
i = input_dict["sample_idx"]
|
||||
beam_size = input_dict["beam_size"]
|
||||
###############
|
||||
# prepare attn embeds
|
||||
################
|
||||
if t == 0:
|
||||
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
|
||||
attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
|
||||
output_i["attn_emb"] = attn_emb
|
||||
output_i["attn_emb_mask"] = attn_emb_mask
|
||||
decoder_input["attn_emb"] = output_i["attn_emb"]
|
||||
decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
|
||||
###############
|
||||
# determine input word
|
||||
################
|
||||
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
|
||||
if t == 0:
|
||||
word = start_word
|
||||
else:
|
||||
word = torch.cat((start_word, output_i["seq"]), dim=-1)
|
||||
decoder_input["word"] = word
|
||||
|
||||
return decoder_input
|
||||
|
||||
|
||||
class EventEncoder(nn.Module):
|
||||
"""
|
||||
Encode the Label information in AudioCaps and AudioSet
|
||||
"""
|
||||
def __init__(self, emb_dim, vocab_size=527):
|
||||
super(EventEncoder, self).__init__()
|
||||
self.label_embedding = nn.Parameter(
|
||||
torch.randn((vocab_size, emb_dim)), requires_grad=True)
|
||||
|
||||
def forward(self, word_idxs):
|
||||
indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
|
||||
embeddings = indices @ self.label_embedding
|
||||
return embeddings
|
||||
|
||||
|
||||
class EventCondTransformerModel(TransformerModel):
|
||||
|
||||
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
||||
if not hasattr(self, "compatible_decoders"):
|
||||
self.compatible_decoders = (
|
||||
captioning.models.decoder.EventTransformerDecoder,
|
||||
)
|
||||
super().__init__(encoder, decoder, **kwargs)
|
||||
self.label_encoder = EventEncoder(decoder.emb_dim, 527)
|
||||
self.train_forward_keys += ["events"]
|
||||
self.inference_forward_keys += ["events"]
|
||||
|
||||
# def seq_forward(self, input_dict):
|
||||
# cap = input_dict["cap"]
|
||||
# cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
||||
# cap_padding_mask = cap_padding_mask[:, :-1]
|
||||
# output = self.decoder(
|
||||
# {
|
||||
# "word": cap[:, :-1],
|
||||
# "attn_emb": input_dict["attn_emb"],
|
||||
# "attn_emb_len": input_dict["attn_emb_len"],
|
||||
# "cap_padding_mask": cap_padding_mask
|
||||
# }
|
||||
# )
|
||||
# return output
|
||||
|
||||
def prepare_decoder_input(self, input_dict, output):
|
||||
decoder_input = super().prepare_decoder_input(input_dict, output)
|
||||
decoder_input["events"] = self.label_encoder(input_dict["events"])
|
||||
return decoder_input
|
||||
|
||||
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
||||
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
|
||||
t = input_dict["t"]
|
||||
i = input_dict["sample_idx"]
|
||||
beam_size = input_dict["beam_size"]
|
||||
if t == 0:
|
||||
output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
|
||||
decoder_input["events"] = output_i["events"]
|
||||
return decoder_input
|
||||
|
||||
|
||||
class KeywordCondTransformerModel(TransformerModel):
|
||||
|
||||
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
||||
if not hasattr(self, "compatible_decoders"):
|
||||
self.compatible_decoders = (
|
||||
captioning.models.decoder.KeywordProbTransformerDecoder,
|
||||
)
|
||||
super().__init__(encoder, decoder, **kwargs)
|
||||
self.train_forward_keys += ["keyword"]
|
||||
self.inference_forward_keys += ["keyword"]
|
||||
|
||||
def seq_forward(self, input_dict):
|
||||
cap = input_dict["cap"]
|
||||
cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
||||
cap_padding_mask = cap_padding_mask[:, :-1]
|
||||
keyword = input_dict["keyword"]
|
||||
output = self.decoder(
|
||||
{
|
||||
"word": cap[:, :-1],
|
||||
"attn_emb": input_dict["attn_emb"],
|
||||
"attn_emb_len": input_dict["attn_emb_len"],
|
||||
"keyword": keyword,
|
||||
"cap_padding_mask": cap_padding_mask
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
def prepare_decoder_input(self, input_dict, output):
|
||||
decoder_input = super().prepare_decoder_input(input_dict, output)
|
||||
decoder_input["keyword"] = input_dict["keyword"]
|
||||
return decoder_input
|
||||
|
||||
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
||||
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
|
||||
t = input_dict["t"]
|
||||
i = input_dict["sample_idx"]
|
||||
beam_size = input_dict["beam_size"]
|
||||
if t == 0:
|
||||
output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
|
||||
beam_size)
|
||||
decoder_input["keyword"] = output_i["keyword"]
|
||||
return decoder_input
|
||||
|
||||
132
audio_to_text/captioning/models/utils.py
Normal file
132
audio_to_text/captioning/models/utils.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
||||
|
||||
|
||||
def sort_pack_padded_sequence(input, lengths):
|
||||
sorted_lengths, indices = torch.sort(lengths, descending=True)
|
||||
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
|
||||
inv_ix = indices.clone()
|
||||
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
|
||||
return tmp, inv_ix
|
||||
|
||||
def pad_unsort_packed_sequence(input, inv_ix):
|
||||
tmp, _ = pad_packed_sequence(input, batch_first=True)
|
||||
tmp = tmp[inv_ix]
|
||||
return tmp
|
||||
|
||||
def pack_wrapper(module, attn_feats, attn_feat_lens):
|
||||
packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
|
||||
if isinstance(module, torch.nn.RNNBase):
|
||||
return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
|
||||
else:
|
||||
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
|
||||
|
||||
def generate_length_mask(lens, max_length=None):
|
||||
lens = torch.as_tensor(lens)
|
||||
N = lens.size(0)
|
||||
if max_length is None:
|
||||
max_length = max(lens)
|
||||
idxs = torch.arange(max_length).repeat(N).view(N, max_length)
|
||||
idxs = idxs.to(lens.device)
|
||||
mask = (idxs < lens.view(-1, 1))
|
||||
return mask
|
||||
|
||||
def mean_with_lens(features, lens):
|
||||
"""
|
||||
features: [N, T, ...] (assume the second dimension represents length)
|
||||
lens: [N,]
|
||||
"""
|
||||
lens = torch.as_tensor(lens)
|
||||
if max(lens) != features.size(1):
|
||||
max_length = features.size(1)
|
||||
mask = generate_length_mask(lens, max_length)
|
||||
else:
|
||||
mask = generate_length_mask(lens)
|
||||
mask = mask.to(features.device) # [N, T]
|
||||
|
||||
while mask.ndim < features.ndim:
|
||||
mask = mask.unsqueeze(-1)
|
||||
feature_mean = features * mask
|
||||
feature_mean = feature_mean.sum(1)
|
||||
while lens.ndim < feature_mean.ndim:
|
||||
lens = lens.unsqueeze(1)
|
||||
feature_mean = feature_mean / lens.to(features.device)
|
||||
# feature_mean = features * mask.unsqueeze(-1)
|
||||
# feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
|
||||
return feature_mean
|
||||
|
||||
def max_with_lens(features, lens):
|
||||
"""
|
||||
features: [N, T, ...] (assume the second dimension represents length)
|
||||
lens: [N,]
|
||||
"""
|
||||
lens = torch.as_tensor(lens)
|
||||
mask = generate_length_mask(lens).to(features.device) # [N, T]
|
||||
|
||||
feature_max = features.clone()
|
||||
feature_max[~mask] = float("-inf")
|
||||
feature_max, _ = feature_max.max(1)
|
||||
return feature_max
|
||||
|
||||
def repeat_tensor(x, n):
|
||||
return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
|
||||
|
||||
def init(m, method="kaiming"):
|
||||
if isinstance(m, (nn.Conv2d, nn.Conv1d)):
|
||||
if method == "kaiming":
|
||||
nn.init.kaiming_uniform_(m.weight)
|
||||
elif method == "xavier":
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
else:
|
||||
raise Exception(f"initialization method {method} not supported")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
if method == "kaiming":
|
||||
nn.init.kaiming_uniform_(m.weight)
|
||||
elif method == "xavier":
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
else:
|
||||
raise Exception(f"initialization method {method} not supported")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Embedding):
|
||||
if method == "kaiming":
|
||||
nn.init.kaiming_uniform_(m.weight)
|
||||
elif method == "xavier":
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
else:
|
||||
raise Exception(f"initialization method {method} not supported")
|
||||
|
||||
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
|
||||
def __init__(self, d_model, dropout=0.1, max_len=100):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
|
||||
(-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
# self.register_buffer("pe", pe)
|
||||
self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
|
||||
|
||||
def forward(self, x):
|
||||
# x: [T, N, E]
|
||||
x = x + self.pe[:x.size(0), :]
|
||||
return self.dropout(x)
|
||||
19
audio_to_text/captioning/utils/README.md
Normal file
19
audio_to_text/captioning/utils/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Utils
|
||||
|
||||
Scripts in this directory are used as utility functions.
|
||||
|
||||
## BERT Pretrained Embeddings
|
||||
|
||||
You can load pretrained word embeddings in Google [BERT](https://github.com/google-research/bert#pre-trained-models) instead of training word embeddings from scratch. The scripts in `utils/bert` need a BERT server in the background. We use BERT server from [bert-as-service](https://github.com/hanxiao/bert-as-service).
|
||||
|
||||
To use bert-as-service, you need to first install the repository. It is recommended that you create a new environment with Tensorflow 1.3 to run BERT server since it is incompatible with Tensorflow 2.x.
|
||||
|
||||
After successful installation of [bert-as-service](https://github.com/hanxiao/bert-as-service), downloading and running the BERT server needs to execute:
|
||||
|
||||
```bash
|
||||
bash scripts/prepare_bert_server.sh <path-to-server> <num-workers> zh
|
||||
```
|
||||
|
||||
By default, server based on BERT base Chinese model is running in the background. You can change to other models by changing corresponding model name and path in `scripts/prepare_bert_server.sh`.
|
||||
|
||||
To extract BERT word embeddings, you need to execute `utils/bert/create_word_embedding.py`.
|
||||
0
audio_to_text/captioning/utils/__init__.py
Normal file
0
audio_to_text/captioning/utils/__init__.py
Normal file
89
audio_to_text/captioning/utils/bert/create_sent_embedding.py
Normal file
89
audio_to_text/captioning/utils/bert/create_sent_embedding.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pickle
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class EmbeddingExtractor(object):
|
||||
|
||||
def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
lang2model = {
|
||||
"zh": "distiluse-base-multilingual-cased",
|
||||
"en": "bert-base-nli-mean-tokens"
|
||||
}
|
||||
lang = "zh" if zh else "en"
|
||||
model = SentenceTransformer(lang2model[lang])
|
||||
|
||||
self.extract(caption_file, model, output, dev)
|
||||
|
||||
def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
|
||||
from bert_serving.client import BertClient
|
||||
client = BertClient(ip)
|
||||
|
||||
self.extract(caption_file, client, output, dev)
|
||||
|
||||
def extract(self, caption_file: str, model, output, dev: bool):
|
||||
caption_df = pd.read_json(caption_file, dtype={"key": str})
|
||||
embeddings = {}
|
||||
|
||||
if dev:
|
||||
with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
|
||||
for idx, row in caption_df.iterrows():
|
||||
caption = row["caption"]
|
||||
key = row["key"]
|
||||
cap_idx = row["caption_index"]
|
||||
embedding = model.encode([caption])
|
||||
embedding = np.array(embedding).reshape(-1)
|
||||
embeddings[f"{key}_{cap_idx}"] = embedding
|
||||
pbar.update()
|
||||
|
||||
else:
|
||||
dump = {}
|
||||
|
||||
with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
|
||||
for idx, row in caption_df.iterrows():
|
||||
key = row["key"]
|
||||
caption = row["caption"]
|
||||
value = np.array(model.encode([caption])).reshape(-1)
|
||||
|
||||
if key not in embeddings.keys():
|
||||
embeddings[key] = [value]
|
||||
else:
|
||||
embeddings[key].append(value)
|
||||
|
||||
pbar.update()
|
||||
|
||||
for key in embeddings:
|
||||
dump[key] = np.stack(embeddings[key])
|
||||
|
||||
embeddings = dump
|
||||
|
||||
with open(output, "wb") as f:
|
||||
pickle.dump(embeddings, f)
|
||||
|
||||
def extract_sbert(self,
|
||||
input_json: str,
|
||||
output: str):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import json
|
||||
import torch
|
||||
from h5py import File
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
data = json.load(open(input_json))["audios"]
|
||||
with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
|
||||
for sample in data:
|
||||
audio_id = sample["audio_id"]
|
||||
for cap in sample["captions"]:
|
||||
cap_id = cap["cap_id"]
|
||||
store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
|
||||
pbar.update()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(EmbeddingExtractor)
|
||||
34
audio_to_text/captioning/utils/bert/create_word_embedding.py
Normal file
34
audio_to_text/captioning/utils/bert/create_word_embedding.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
from bert_serving.client import BertClient
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import fire
|
||||
import torch
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
from utils.build_vocab import Vocabulary
|
||||
|
||||
def main(vocab_file: str, output: str, server_hostname: str):
|
||||
client = BertClient(ip=server_hostname)
|
||||
vocabulary = torch.load(vocab_file)
|
||||
vocab_size = len(vocabulary)
|
||||
|
||||
fake_embedding = client.encode(["test"]).reshape(-1)
|
||||
embed_size = fake_embedding.shape[0]
|
||||
|
||||
print("Encoding words into embeddings with size: ", embed_size)
|
||||
|
||||
embeddings = np.empty((vocab_size, embed_size))
|
||||
for i in tqdm(range(len(embeddings)), ascii=True):
|
||||
embeddings[i] = client.encode([vocabulary.idx2word[i]])
|
||||
np.save(output, embeddings)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(main)
|
||||
|
||||
|
||||
153
audio_to_text/captioning/utils/build_vocab.py
Normal file
153
audio_to_text/captioning/utils/build_vocab.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import pickle
|
||||
from collections import Counter
|
||||
import re
|
||||
import fire
|
||||
|
||||
|
||||
class Vocabulary(object):
|
||||
"""Simple vocabulary wrapper."""
|
||||
def __init__(self):
|
||||
self.word2idx = {}
|
||||
self.idx2word = {}
|
||||
self.idx = 0
|
||||
|
||||
def add_word(self, word):
|
||||
if not word in self.word2idx:
|
||||
self.word2idx[word] = self.idx
|
||||
self.idx2word[self.idx] = word
|
||||
self.idx += 1
|
||||
|
||||
def __call__(self, word):
|
||||
if not word in self.word2idx:
|
||||
return self.word2idx["<unk>"]
|
||||
return self.word2idx[word]
|
||||
|
||||
def __getitem__(self, word_id):
|
||||
return self.idx2word[word_id]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.word2idx)
|
||||
|
||||
|
||||
def build_vocab(input_json: str,
|
||||
threshold: int,
|
||||
keep_punctuation: bool,
|
||||
host_address: str,
|
||||
character_level: bool = False,
|
||||
zh: bool = True ):
|
||||
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
|
||||
|
||||
Args:
|
||||
input_json(string): Preprossessed json file. Structure like this:
|
||||
{
|
||||
'audios': [
|
||||
{
|
||||
'audio_id': 'xxx',
|
||||
'captions': [
|
||||
{
|
||||
'caption': 'xxx',
|
||||
'cap_id': 'xxx'
|
||||
}
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
threshold (int): Threshold to drop all words with counts < threshold
|
||||
keep_punctuation (bool): Includes or excludes punctuation.
|
||||
|
||||
Returns:
|
||||
vocab (Vocab): Object with the processed vocabulary
|
||||
"""
|
||||
data = json.load(open(input_json, "r"))["audios"]
|
||||
counter = Counter()
|
||||
pretokenized = "tokens" in data[0]["captions"][0]
|
||||
|
||||
if zh:
|
||||
from nltk.parse.corenlp import CoreNLPParser
|
||||
from zhon.hanzi import punctuation
|
||||
if not pretokenized:
|
||||
parser = CoreNLPParser(host_address)
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
if pretokenized:
|
||||
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
||||
else:
|
||||
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
||||
# Remove all punctuations
|
||||
if not keep_punctuation:
|
||||
caption = re.sub("[{}]".format(punctuation), "", caption)
|
||||
if character_level:
|
||||
tokens = list(caption)
|
||||
else:
|
||||
tokens = list(parser.tokenize(caption))
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
|
||||
counter.update(tokens)
|
||||
else:
|
||||
if pretokenized:
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
||||
counter.update(tokens)
|
||||
else:
|
||||
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
||||
captions = {}
|
||||
for audio_idx in range(len(data)):
|
||||
audio_id = data[audio_idx]["audio_id"]
|
||||
captions[audio_id] = []
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
||||
captions[audio_id].append({
|
||||
"audio_id": audio_id,
|
||||
"id": cap_idx,
|
||||
"caption": caption
|
||||
})
|
||||
tokenizer = PTBTokenizer()
|
||||
captions = tokenizer.tokenize(captions)
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
audio_id = data[audio_idx]["audio_id"]
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
tokens = captions[audio_id][cap_idx]
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
|
||||
counter.update(tokens.split(" "))
|
||||
|
||||
if not pretokenized:
|
||||
json.dump({ "audios": data }, open(input_json, "w"), indent=4, ensure_ascii=not zh)
|
||||
words = [word for word, cnt in counter.items() if cnt >= threshold]
|
||||
|
||||
# Create a vocab wrapper and add some special tokens.
|
||||
vocab = Vocabulary()
|
||||
vocab.add_word("<pad>")
|
||||
vocab.add_word("<start>")
|
||||
vocab.add_word("<end>")
|
||||
vocab.add_word("<unk>")
|
||||
|
||||
# Add the words to the vocabulary.
|
||||
for word in words:
|
||||
vocab.add_word(word)
|
||||
return vocab
|
||||
|
||||
|
||||
def process(input_json: str,
|
||||
output_file: str,
|
||||
threshold: int = 1,
|
||||
keep_punctuation: bool = False,
|
||||
character_level: bool = False,
|
||||
host_address: str = "http://localhost:9000",
|
||||
zh: bool = False):
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
logging.info("Build Vocab")
|
||||
vocabulary = build_vocab(
|
||||
input_json=input_json, threshold=threshold, keep_punctuation=keep_punctuation,
|
||||
host_address=host_address, character_level=character_level, zh=zh)
|
||||
pickle.dump(vocabulary, open(output_file, "wb"))
|
||||
logging.info("Total vocabulary size: {}".format(len(vocabulary)))
|
||||
logging.info("Saved vocab to '{}'".format(output_file))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(process)
|
||||
150
audio_to_text/captioning/utils/build_vocab_ltp.py
Normal file
150
audio_to_text/captioning/utils/build_vocab_ltp.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import pickle
|
||||
from collections import Counter
|
||||
import re
|
||||
import fire
|
||||
|
||||
class Vocabulary(object):
|
||||
"""Simple vocabulary wrapper."""
|
||||
def __init__(self):
|
||||
self.word2idx = {}
|
||||
self.idx2word = {}
|
||||
self.idx = 0
|
||||
|
||||
def add_word(self, word):
|
||||
if not word in self.word2idx:
|
||||
self.word2idx[word] = self.idx
|
||||
self.idx2word[self.idx] = word
|
||||
self.idx += 1
|
||||
|
||||
def __call__(self, word):
|
||||
if not word in self.word2idx:
|
||||
return self.word2idx["<unk>"]
|
||||
return self.word2idx[word]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.word2idx)
|
||||
|
||||
def build_vocab(input_json: str,
|
||||
output_json: str,
|
||||
threshold: int,
|
||||
keep_punctuation: bool,
|
||||
character_level: bool = False,
|
||||
zh: bool = True ):
|
||||
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
|
||||
|
||||
Args:
|
||||
input_json(string): Preprossessed json file. Structure like this:
|
||||
{
|
||||
'audios': [
|
||||
{
|
||||
'audio_id': 'xxx',
|
||||
'captions': [
|
||||
{
|
||||
'caption': 'xxx',
|
||||
'cap_id': 'xxx'
|
||||
}
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
threshold (int): Threshold to drop all words with counts < threshold
|
||||
keep_punctuation (bool): Includes or excludes punctuation.
|
||||
|
||||
Returns:
|
||||
vocab (Vocab): Object with the processed vocabulary
|
||||
"""
|
||||
data = json.load(open(input_json, "r"))["audios"]
|
||||
counter = Counter()
|
||||
pretokenized = "tokens" in data[0]["captions"][0]
|
||||
|
||||
if zh:
|
||||
from ltp import LTP
|
||||
from zhon.hanzi import punctuation
|
||||
if not pretokenized:
|
||||
parser = LTP("base")
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
if pretokenized:
|
||||
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
||||
else:
|
||||
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
||||
if character_level:
|
||||
tokens = list(caption)
|
||||
else:
|
||||
tokens, _ = parser.seg([caption])
|
||||
tokens = tokens[0]
|
||||
# Remove all punctuations
|
||||
if not keep_punctuation:
|
||||
tokens = [token for token in tokens if token not in punctuation]
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
|
||||
counter.update(tokens)
|
||||
else:
|
||||
if pretokenized:
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
||||
counter.update(tokens)
|
||||
else:
|
||||
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
||||
captions = {}
|
||||
for audio_idx in range(len(data)):
|
||||
audio_id = data[audio_idx]["audio_id"]
|
||||
captions[audio_id] = []
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
||||
captions[audio_id].append({
|
||||
"audio_id": audio_id,
|
||||
"id": cap_idx,
|
||||
"caption": caption
|
||||
})
|
||||
tokenizer = PTBTokenizer()
|
||||
captions = tokenizer.tokenize(captions)
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
audio_id = data[audio_idx]["audio_id"]
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
tokens = captions[audio_id][cap_idx]
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
|
||||
counter.update(tokens.split(" "))
|
||||
|
||||
if not pretokenized:
|
||||
if output_json is None:
|
||||
output_json = input_json
|
||||
json.dump({ "audios": data }, open(output_json, "w"), indent=4, ensure_ascii=not zh)
|
||||
words = [word for word, cnt in counter.items() if cnt >= threshold]
|
||||
|
||||
# Create a vocab wrapper and add some special tokens.
|
||||
vocab = Vocabulary()
|
||||
vocab.add_word("<pad>")
|
||||
vocab.add_word("<start>")
|
||||
vocab.add_word("<end>")
|
||||
vocab.add_word("<unk>")
|
||||
|
||||
# Add the words to the vocabulary.
|
||||
for word in words:
|
||||
vocab.add_word(word)
|
||||
return vocab
|
||||
|
||||
def process(input_json: str,
|
||||
output_file: str,
|
||||
output_json: str = None,
|
||||
threshold: int = 1,
|
||||
keep_punctuation: bool = False,
|
||||
character_level: bool = False,
|
||||
zh: bool = True):
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
logging.info("Build Vocab")
|
||||
vocabulary = build_vocab(
|
||||
input_json=input_json, output_json=output_json, threshold=threshold,
|
||||
keep_punctuation=keep_punctuation, character_level=character_level, zh=zh)
|
||||
pickle.dump(vocabulary, open(output_file, "wb"))
|
||||
logging.info("Total vocabulary size: {}".format(len(vocabulary)))
|
||||
logging.info("Saved vocab to '{}'".format(output_file))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(process)
|
||||
152
audio_to_text/captioning/utils/build_vocab_spacy.py
Normal file
152
audio_to_text/captioning/utils/build_vocab_spacy.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import pickle
|
||||
from collections import Counter
|
||||
import re
|
||||
import fire
|
||||
|
||||
class Vocabulary(object):
|
||||
"""Simple vocabulary wrapper."""
|
||||
def __init__(self):
|
||||
self.word2idx = {}
|
||||
self.idx2word = {}
|
||||
self.idx = 0
|
||||
|
||||
def add_word(self, word):
|
||||
if not word in self.word2idx:
|
||||
self.word2idx[word] = self.idx
|
||||
self.idx2word[self.idx] = word
|
||||
self.idx += 1
|
||||
|
||||
def __call__(self, word):
|
||||
if not word in self.word2idx:
|
||||
return self.word2idx["<unk>"]
|
||||
return self.word2idx[word]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.word2idx)
|
||||
|
||||
|
||||
def build_vocab(input_json: str,
|
||||
output_json: str,
|
||||
threshold: int,
|
||||
keep_punctuation: bool,
|
||||
host_address: str,
|
||||
character_level: bool = False,
|
||||
retokenize: bool = True,
|
||||
zh: bool = True ):
|
||||
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
|
||||
|
||||
Args:
|
||||
input_json(string): Preprossessed json file. Structure like this:
|
||||
{
|
||||
'audios': [
|
||||
{
|
||||
'audio_id': 'xxx',
|
||||
'captions': [
|
||||
{
|
||||
'caption': 'xxx',
|
||||
'cap_id': 'xxx'
|
||||
}
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
threshold (int): Threshold to drop all words with counts < threshold
|
||||
keep_punctuation (bool): Includes or excludes punctuation.
|
||||
|
||||
Returns:
|
||||
vocab (Vocab): Object with the processed vocabulary
|
||||
"""
|
||||
data = json.load(open(input_json, "r"))["audios"]
|
||||
counter = Counter()
|
||||
if retokenize:
|
||||
pretokenized = False
|
||||
else:
|
||||
pretokenized = "tokens" in data[0]["captions"][0]
|
||||
|
||||
if zh:
|
||||
from nltk.parse.corenlp import CoreNLPParser
|
||||
from zhon.hanzi import punctuation
|
||||
if not pretokenized:
|
||||
parser = CoreNLPParser(host_address)
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
if pretokenized:
|
||||
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
||||
else:
|
||||
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
||||
# Remove all punctuations
|
||||
if not keep_punctuation:
|
||||
caption = re.sub("[{}]".format(punctuation), "", caption)
|
||||
if character_level:
|
||||
tokens = list(caption)
|
||||
else:
|
||||
tokens = list(parser.tokenize(caption))
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
|
||||
counter.update(tokens)
|
||||
else:
|
||||
if pretokenized:
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
||||
counter.update(tokens)
|
||||
else:
|
||||
import spacy
|
||||
tokenizer = spacy.load("en_core_web_sm", disable=["parser", "ner"])
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
captions = data[audio_idx]["captions"]
|
||||
for cap_idx in range(len(captions)):
|
||||
caption = captions[cap_idx]["caption"]
|
||||
doc = tokenizer(caption)
|
||||
tokens = " ".join([str(token).lower() for token in doc])
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
|
||||
counter.update(tokens.split(" "))
|
||||
|
||||
if not pretokenized:
|
||||
if output_json is None:
|
||||
json.dump({ "audios": data }, open(input_json, "w"),
|
||||
indent=4, ensure_ascii=not zh)
|
||||
else:
|
||||
json.dump({ "audios": data }, open(output_json, "w"),
|
||||
indent=4, ensure_ascii=not zh)
|
||||
|
||||
words = [word for word, cnt in counter.items() if cnt >= threshold]
|
||||
|
||||
# Create a vocab wrapper and add some special tokens.
|
||||
vocab = Vocabulary()
|
||||
vocab.add_word("<pad>")
|
||||
vocab.add_word("<start>")
|
||||
vocab.add_word("<end>")
|
||||
vocab.add_word("<unk>")
|
||||
|
||||
# Add the words to the vocabulary.
|
||||
for word in words:
|
||||
vocab.add_word(word)
|
||||
return vocab
|
||||
|
||||
def process(input_json: str,
|
||||
output_file: str,
|
||||
output_json: str = None,
|
||||
threshold: int = 1,
|
||||
keep_punctuation: bool = False,
|
||||
character_level: bool = False,
|
||||
retokenize: bool = False,
|
||||
host_address: str = "http://localhost:9000",
|
||||
zh: bool = True):
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
logging.info("Build Vocab")
|
||||
vocabulary = build_vocab(
|
||||
input_json=input_json, output_json=output_json, threshold=threshold,
|
||||
keep_punctuation=keep_punctuation, host_address=host_address,
|
||||
character_level=character_level, retokenize=retokenize, zh=zh)
|
||||
pickle.dump(vocabulary, open(output_file, "wb"))
|
||||
logging.info("Total vocabulary size: {}".format(len(vocabulary)))
|
||||
logging.info("Saved vocab to '{}'".format(output_file))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(process)
|
||||
182
audio_to_text/captioning/utils/eval_round_robin.py
Normal file
182
audio_to_text/captioning/utils/eval_round_robin.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import copy
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import fire
|
||||
|
||||
|
||||
def evaluate_annotation(key2refs, scorer):
|
||||
if scorer.method() == "Bleu":
|
||||
scores = np.array([ 0.0 for n in range(4) ])
|
||||
else:
|
||||
scores = 0
|
||||
num_cap_per_audio = len(next(iter(key2refs.values())))
|
||||
|
||||
for i in range(num_cap_per_audio):
|
||||
if i > 0:
|
||||
for key in key2refs:
|
||||
key2refs[key].insert(0, res[key][0])
|
||||
res = { key: [refs.pop(),] for key, refs in key2refs.items() }
|
||||
score, _ = scorer.compute_score(key2refs, res)
|
||||
|
||||
if scorer.method() == "Bleu":
|
||||
scores += np.array(score)
|
||||
else:
|
||||
scores += score
|
||||
|
||||
score = scores / num_cap_per_audio
|
||||
return score
|
||||
|
||||
def evaluate_prediction(key2pred, key2refs, scorer):
|
||||
if scorer.method() == "Bleu":
|
||||
scores = np.array([ 0.0 for n in range(4) ])
|
||||
else:
|
||||
scores = 0
|
||||
num_cap_per_audio = len(next(iter(key2refs.values())))
|
||||
|
||||
for i in range(num_cap_per_audio):
|
||||
key2refs_i = {}
|
||||
for key, refs in key2refs.items():
|
||||
key2refs_i[key] = refs[:i] + refs[i+1:]
|
||||
score, _ = scorer.compute_score(key2refs_i, key2pred)
|
||||
|
||||
if scorer.method() == "Bleu":
|
||||
scores += np.array(score)
|
||||
else:
|
||||
scores += score
|
||||
|
||||
score = scores / num_cap_per_audio
|
||||
return score
|
||||
|
||||
|
||||
class Evaluator(object):
|
||||
|
||||
def eval_annotation(self, annotation, output):
|
||||
captions = json.load(open(annotation, "r"))["audios"]
|
||||
|
||||
key2refs = {}
|
||||
for audio_idx in range(len(captions)):
|
||||
audio_id = captions[audio_idx]["audio_id"]
|
||||
key2refs[audio_id] = []
|
||||
for caption in captions[audio_idx]["captions"]:
|
||||
key2refs[audio_id].append(caption["caption"])
|
||||
|
||||
from fense.fense import Fense
|
||||
scores = {}
|
||||
scorer = Fense()
|
||||
scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
|
||||
|
||||
refs4eval = {}
|
||||
for key, refs in key2refs.items():
|
||||
refs4eval[key] = []
|
||||
for idx, ref in enumerate(refs):
|
||||
refs4eval[key].append({
|
||||
"audio_id": key,
|
||||
"id": idx,
|
||||
"caption": ref
|
||||
})
|
||||
|
||||
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
||||
|
||||
tokenizer = PTBTokenizer()
|
||||
key2refs = tokenizer.tokenize(refs4eval)
|
||||
|
||||
|
||||
from pycocoevalcap.bleu.bleu import Bleu
|
||||
from pycocoevalcap.cider.cider import Cider
|
||||
from pycocoevalcap.rouge.rouge import Rouge
|
||||
from pycocoevalcap.meteor.meteor import Meteor
|
||||
from pycocoevalcap.spice.spice import Spice
|
||||
|
||||
|
||||
scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
|
||||
for scorer in scorers:
|
||||
scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
|
||||
|
||||
spider = 0
|
||||
with open(output, "w") as f:
|
||||
for name, score in scores.items():
|
||||
if name == "Bleu":
|
||||
for n in range(4):
|
||||
f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
|
||||
else:
|
||||
f.write("{}: {:6.3f}\n".format(name, score))
|
||||
if name in ["CIDEr", "SPICE"]:
|
||||
spider += score
|
||||
f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
|
||||
|
||||
def eval_prediction(self, prediction, annotation, output):
|
||||
ref_captions = json.load(open(annotation, "r"))["audios"]
|
||||
|
||||
key2refs = {}
|
||||
for audio_idx in range(len(ref_captions)):
|
||||
audio_id = ref_captions[audio_idx]["audio_id"]
|
||||
key2refs[audio_id] = []
|
||||
for caption in ref_captions[audio_idx]["captions"]:
|
||||
key2refs[audio_id].append(caption["caption"])
|
||||
|
||||
pred_captions = json.load(open(prediction, "r"))["predictions"]
|
||||
|
||||
key2pred = {}
|
||||
for audio_idx in range(len(pred_captions)):
|
||||
item = pred_captions[audio_idx]
|
||||
audio_id = item["filename"]
|
||||
key2pred[audio_id] = [item["tokens"]]
|
||||
|
||||
from fense.fense import Fense
|
||||
scores = {}
|
||||
scorer = Fense()
|
||||
scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
|
||||
|
||||
refs4eval = {}
|
||||
for key, refs in key2refs.items():
|
||||
refs4eval[key] = []
|
||||
for idx, ref in enumerate(refs):
|
||||
refs4eval[key].append({
|
||||
"audio_id": key,
|
||||
"id": idx,
|
||||
"caption": ref
|
||||
})
|
||||
|
||||
preds4eval = {}
|
||||
for key, preds in key2pred.items():
|
||||
preds4eval[key] = []
|
||||
for idx, pred in enumerate(preds):
|
||||
preds4eval[key].append({
|
||||
"audio_id": key,
|
||||
"id": idx,
|
||||
"caption": pred
|
||||
})
|
||||
|
||||
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
||||
|
||||
tokenizer = PTBTokenizer()
|
||||
key2refs = tokenizer.tokenize(refs4eval)
|
||||
key2pred = tokenizer.tokenize(preds4eval)
|
||||
|
||||
|
||||
from pycocoevalcap.bleu.bleu import Bleu
|
||||
from pycocoevalcap.cider.cider import Cider
|
||||
from pycocoevalcap.rouge.rouge import Rouge
|
||||
from pycocoevalcap.meteor.meteor import Meteor
|
||||
from pycocoevalcap.spice.spice import Spice
|
||||
|
||||
scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
|
||||
for scorer in scorers:
|
||||
scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
|
||||
|
||||
spider = 0
|
||||
with open(output, "w") as f:
|
||||
for name, score in scores.items():
|
||||
if name == "Bleu":
|
||||
for n in range(4):
|
||||
f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
|
||||
else:
|
||||
f.write("{}: {:6.3f}\n".format(name, score))
|
||||
if name in ["CIDEr", "SPICE"]:
|
||||
spider += score
|
||||
f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Evaluator)
|
||||
@@ -0,0 +1,50 @@
|
||||
# coding=utf-8
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from gensim.models import FastText
|
||||
from tqdm import tqdm
|
||||
import fire
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.getcwd())
|
||||
from utils.build_vocab import Vocabulary
|
||||
|
||||
def create_embedding(caption_file: str,
|
||||
vocab_file: str,
|
||||
embed_size: int,
|
||||
output: str,
|
||||
**fasttext_kwargs):
|
||||
caption_df = pd.read_json(caption_file)
|
||||
caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])
|
||||
|
||||
sentences = list(caption_df["tokens"].values)
|
||||
vocabulary = torch.load(vocab_file, map_location="cpu")
|
||||
|
||||
epochs = fasttext_kwargs.get("epochs", 10)
|
||||
model = FastText(size=embed_size, min_count=1, **fasttext_kwargs)
|
||||
model.build_vocab(sentences=sentences)
|
||||
model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
|
||||
|
||||
word_embeddings = np.zeros((len(vocabulary), embed_size))
|
||||
|
||||
with tqdm(total=len(vocabulary), ascii=True) as pbar:
|
||||
for word, idx in vocabulary.word2idx.items():
|
||||
if word == "<pad>" or word == "<unk>":
|
||||
continue
|
||||
word_embeddings[idx] = model.wv[word]
|
||||
pbar.update()
|
||||
|
||||
np.save(output, word_embeddings)
|
||||
|
||||
print("Finish writing fasttext embeddings to " + output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(create_embedding)
|
||||
|
||||
|
||||
|
||||
128
audio_to_text/captioning/utils/lr_scheduler.py
Normal file
128
audio_to_text/captioning/utils/lr_scheduler.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class ExponentialDecayScheduler(torch.optim.lr_scheduler._LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, total_iters, final_lrs,
|
||||
warmup_iters=3000, last_epoch=-1, verbose=False):
|
||||
self.total_iters = total_iters
|
||||
self.final_lrs = final_lrs
|
||||
if not isinstance(self.final_lrs, list) and not isinstance(
|
||||
self.final_lrs, tuple):
|
||||
self.final_lrs = [self.final_lrs] * len(optimizer.param_groups)
|
||||
self.warmup_iters = warmup_iters
|
||||
self.bases = [0.0,] * len(optimizer.param_groups)
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
for i, (base_lr, final_lr) in enumerate(zip(self.base_lrs, self.final_lrs)):
|
||||
base = (final_lr / base_lr) ** (1 / (
|
||||
self.total_iters - self.warmup_iters))
|
||||
self.bases[i] = base
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
warmup_coeff = 1.0
|
||||
current_iter = self._step_count
|
||||
if current_iter < self.warmup_iters:
|
||||
warmup_coeff = current_iter / self.warmup_iters
|
||||
current_lrs = []
|
||||
# if not self.linear_warmup:
|
||||
# for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs, self.bases):
|
||||
# # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
|
||||
# current_lr = warmup_coeff * base_lr * (base ** (current_iter - self.warmup_iters))
|
||||
# current_lrs.append(current_lr)
|
||||
# else:
|
||||
for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs,
|
||||
self.bases):
|
||||
if current_iter <= self.warmup_iters:
|
||||
current_lr = warmup_coeff * base_lr
|
||||
else:
|
||||
# current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
|
||||
current_lr = base_lr * (base ** (current_iter - self.warmup_iters))
|
||||
current_lrs.append(current_lr)
|
||||
return current_lrs
|
||||
|
||||
def get_lr(self):
|
||||
return self._get_closed_form_lr()
|
||||
|
||||
|
||||
class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, model_size=512, factor=1, warmup_iters=3000,
|
||||
last_epoch=-1, verbose=False):
|
||||
self.model_size = model_size
|
||||
self.warmup_iters = warmup_iters
|
||||
# self.factors = [group["lr"] / (self.model_size ** (-0.5) * self.warmup_iters ** (-0.5)) for group in optimizer.param_groups]
|
||||
self.factor = factor
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
current_iter = self._step_count
|
||||
current_lrs = []
|
||||
for _ in self.base_lrs:
|
||||
current_lr = self.factor * \
|
||||
(self.model_size ** (-0.5) * min(current_iter ** (-0.5),
|
||||
current_iter * self.warmup_iters ** (-1.5)))
|
||||
current_lrs.append(current_lr)
|
||||
return current_lrs
|
||||
|
||||
def get_lr(self):
|
||||
return self._get_closed_form_lr()
|
||||
|
||||
|
||||
class CosineWithWarmup(torch.optim.lr_scheduler._LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, total_iters, warmup_iters,
|
||||
num_cycles=0.5, last_epoch=-1, verbose=False):
|
||||
self.total_iters = total_iters
|
||||
self.warmup_iters = warmup_iters
|
||||
self.num_cycles = num_cycles
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def lr_lambda(self, iteration):
|
||||
if iteration < self.warmup_iters:
|
||||
return float(iteration) / float(max(1, self.warmup_iters))
|
||||
progress = float(iteration - self.warmup_iters) / float(max(1,
|
||||
self.total_iters - self.warmup_iters))
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(
|
||||
self.num_cycles) * 2.0 * progress)))
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
current_iter = self._step_count
|
||||
current_lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
current_lr = base_lr * self.lr_lambda(current_iter)
|
||||
current_lrs.append(current_lr)
|
||||
return current_lrs
|
||||
|
||||
def get_lr(self):
|
||||
return self._get_closed_form_lr()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = torch.optim.Adam(model.parameters(), 5e-4)
|
||||
epochs = 25
|
||||
iters = 600
|
||||
scheduler = CosineWithWarmup(optimizer, 600 * 25, 600 * 5,)
|
||||
# scheduler = ExponentialDecayScheduler(optimizer, 600 * 25, 5e-7, 600 * 5)
|
||||
criterion = torch.nn.MSELoss()
|
||||
lrs = []
|
||||
for epoch in range(1, epochs + 1):
|
||||
for iteration in range(1, iters + 1):
|
||||
optimizer.zero_grad()
|
||||
x = torch.randn(4, 10)
|
||||
y = torch.randn(4, 5)
|
||||
loss = criterion(model(x), y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
# print(f"lr: {scheduler.get_last_lr()}")
|
||||
# lrs.append(scheduler.get_last_lr())
|
||||
lrs.append(optimizer.param_groups[0]["lr"])
|
||||
import matplotlib.pyplot as plt
|
||||
plt.plot(list(range(1, len(lrs) + 1)), lrs, '-o', markersize=1)
|
||||
# plt.legend(loc="best")
|
||||
plt.xlabel("Iteration")
|
||||
plt.ylabel("LR")
|
||||
|
||||
plt.savefig("lr_curve.png", dpi=100)
|
||||
110
audio_to_text/captioning/utils/model_eval_diff.py
Normal file
110
audio_to_text/captioning/utils/model_eval_diff.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import fire
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
|
||||
def coco_score(refs, pred, scorer):
|
||||
if scorer.method() == "Bleu":
|
||||
scores = np.array([ 0.0 for n in range(4) ])
|
||||
else:
|
||||
scores = 0
|
||||
num_cap_per_audio = len(refs[list(refs.keys())[0]])
|
||||
|
||||
for i in range(num_cap_per_audio):
|
||||
if i > 0:
|
||||
for key in refs:
|
||||
refs[key].insert(0, res[key][0])
|
||||
res = {key: [refs[key].pop(),] for key in refs}
|
||||
score, _ = scorer.compute_score(refs, pred)
|
||||
|
||||
if scorer.method() == "Bleu":
|
||||
scores += np.array(score)
|
||||
else:
|
||||
scores += score
|
||||
|
||||
score = scores / num_cap_per_audio
|
||||
|
||||
for key in refs:
|
||||
refs[key].insert(0, res[key][0])
|
||||
score_allref, _ = scorer.compute_score(refs, pred)
|
||||
diff = score_allref - score
|
||||
return diff
|
||||
|
||||
def embedding_score(refs, pred, scorer):
|
||||
|
||||
num_cap_per_audio = len(refs[list(refs.keys())[0]])
|
||||
scores = 0
|
||||
|
||||
for i in range(num_cap_per_audio):
|
||||
res = {key: [refs[key][i],] for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
|
||||
refs_i = {key: np.concatenate([refs[key][:i], refs[key][i+1:]]) for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
|
||||
score, _ = scorer.compute_score(refs_i, pred)
|
||||
|
||||
scores += score
|
||||
|
||||
score = scores / num_cap_per_audio
|
||||
|
||||
score_allref, _ = scorer.compute_score(refs, pred)
|
||||
diff = score_allref - score
|
||||
return diff
|
||||
|
||||
def main(output_file, eval_caption_file, eval_embedding_file, output, zh=False):
|
||||
output_df = pd.read_json(output_file)
|
||||
output_df["key"] = output_df["filename"].apply(lambda x: os.path.splitext(os.path.basename(x))[0])
|
||||
pred = output_df.groupby("key")["tokens"].apply(list).to_dict()
|
||||
|
||||
label_df = pd.read_json(eval_caption_file)
|
||||
if zh:
|
||||
refs = label_df.groupby("key")["tokens"].apply(list).to_dict()
|
||||
else:
|
||||
refs = label_df.groupby("key")["caption"].apply(list).to_dict()
|
||||
|
||||
from pycocoevalcap.bleu.bleu import Bleu
|
||||
from pycocoevalcap.cider.cider import Cider
|
||||
from pycocoevalcap.rouge.rouge import Rouge
|
||||
|
||||
scorer = Bleu(zh=zh)
|
||||
bleu_scores = coco_score(copy.deepcopy(refs), pred, scorer)
|
||||
scorer = Cider(zh=zh)
|
||||
cider_score = coco_score(copy.deepcopy(refs), pred, scorer)
|
||||
scorer = Rouge(zh=zh)
|
||||
rouge_score = coco_score(copy.deepcopy(refs), pred, scorer)
|
||||
|
||||
if not zh:
|
||||
from pycocoevalcap.meteor.meteor import Meteor
|
||||
scorer = Meteor()
|
||||
meteor_score = coco_score(copy.deepcopy(refs), pred, scorer)
|
||||
|
||||
from pycocoevalcap.spice.spice import Spice
|
||||
scorer = Spice()
|
||||
spice_score = coco_score(copy.deepcopy(refs), pred, scorer)
|
||||
|
||||
# from audiocaptioneval.sentbert.sentencebert import SentenceBert
|
||||
# scorer = SentenceBert(zh=zh)
|
||||
# with open(eval_embedding_file, "rb") as f:
|
||||
# ref_embeddings = pickle.load(f)
|
||||
|
||||
# sent_bert = embedding_score(ref_embeddings, pred, scorer)
|
||||
|
||||
with open(output, "w") as f:
|
||||
f.write("Diff:\n")
|
||||
for n in range(4):
|
||||
f.write("BLEU-{}: {:6.3f}\n".format(n+1, bleu_scores[n]))
|
||||
f.write("CIDEr: {:6.3f}\n".format(cider_score))
|
||||
f.write("ROUGE: {:6.3f}\n".format(rouge_score))
|
||||
if not zh:
|
||||
f.write("Meteor: {:6.3f}\n".format(meteor_score))
|
||||
f.write("SPICE: {:6.3f}\n".format(spice_score))
|
||||
# f.write("SentenceBert: {:6.3f}\n".format(sent_bert))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
49
audio_to_text/captioning/utils/predict_nn.py
Normal file
49
audio_to_text/captioning/utils/predict_nn.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import json
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from h5py import File
|
||||
import sklearn.metrics
|
||||
|
||||
random.seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_feature", type=str)
|
||||
parser.add_argument("train_corpus", type=str)
|
||||
parser.add_argument("pred_feature", type=str)
|
||||
parser.add_argument("output_json", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
train_embs = []
|
||||
train_idx_to_audioid = []
|
||||
with File(args.train_feature, "r") as store:
|
||||
for audio_id, embedding in tqdm(store.items(), ascii=True):
|
||||
train_embs.append(embedding[()])
|
||||
train_idx_to_audioid.append(audio_id)
|
||||
|
||||
train_annotation = json.load(open(args.train_corpus, "r"))["audios"]
|
||||
train_audioid_to_tokens = {}
|
||||
for item in train_annotation:
|
||||
audio_id = item["audio_id"]
|
||||
train_audioid_to_tokens[audio_id] = [cap_item["tokens"] for cap_item in item["captions"]]
|
||||
train_embs = np.stack(train_embs)
|
||||
|
||||
|
||||
pred_data = []
|
||||
pred_embs = []
|
||||
pred_idx_to_audioids = []
|
||||
with File(args.pred_feature, "r") as store:
|
||||
for audio_id, embedding in tqdm(store.items(), ascii=True):
|
||||
pred_embs.append(embedding[()])
|
||||
pred_idx_to_audioids.append(audio_id)
|
||||
pred_embs = np.stack(pred_embs)
|
||||
|
||||
similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
|
||||
for idx, audio_id in enumerate(pred_idx_to_audioids):
|
||||
train_idx = similarity[idx].argmax()
|
||||
pred_data.append({
|
||||
"filename": audio_id,
|
||||
"tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]])
|
||||
})
|
||||
json.dump({"predictions": pred_data}, open(args.output_json, "w"), ensure_ascii=False, indent=4)
|
||||
18
audio_to_text/captioning/utils/remove_optimizer.py
Normal file
18
audio_to_text/captioning/utils/remove_optimizer.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
|
||||
def main(checkpoint):
|
||||
state_dict = torch.load(checkpoint, map_location="cpu")
|
||||
if "optimizer" in state_dict:
|
||||
del state_dict["optimizer"]
|
||||
if "lr_scheduler" in state_dict:
|
||||
del state_dict["lr_scheduler"]
|
||||
torch.save(state_dict, checkpoint)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("checkpoint", type=str)
|
||||
args = parser.parse_args()
|
||||
main(args.checkpoint)
|
||||
37
audio_to_text/captioning/utils/report_results.py
Normal file
37
audio_to_text/captioning/utils/report_results.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", help="input filename", type=str, nargs="+")
|
||||
parser.add_argument("--output", help="output result file", default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
scores = {}
|
||||
for path in args.input:
|
||||
with open(path, "r") as reader:
|
||||
for line in reader.readlines():
|
||||
metric, score = line.strip().split(": ")
|
||||
score = float(score)
|
||||
if metric not in scores:
|
||||
scores[metric] = []
|
||||
scores[metric].append(score)
|
||||
|
||||
if len(scores) == 0:
|
||||
print("No experiment directory found, wrong path?")
|
||||
exit(1)
|
||||
|
||||
with open(args.output, "w") as writer:
|
||||
print("Average results: ", file=writer)
|
||||
for metric, score in scores.items():
|
||||
score = np.array(score)
|
||||
mean = np.mean(score)
|
||||
std = np.std(score)
|
||||
print(f"{metric}: {mean:.3f} (±{std:.3f})", file=writer)
|
||||
print("", file=writer)
|
||||
print("Best results: ", file=writer)
|
||||
for metric, score in scores.items():
|
||||
score = np.max(score)
|
||||
print(f"{metric}: {score:.3f}", file=writer)
|
||||
86
audio_to_text/captioning/utils/tokenize_caption.py
Normal file
86
audio_to_text/captioning/utils/tokenize_caption.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import fire
|
||||
|
||||
|
||||
def tokenize_caption(input_json: str,
|
||||
keep_punctuation: bool = False,
|
||||
host_address: str = None,
|
||||
character_level: bool = False,
|
||||
zh: bool = True,
|
||||
output_json: str = None):
|
||||
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
|
||||
|
||||
Args:
|
||||
input_json(string): Preprossessed json file. Structure like this:
|
||||
{
|
||||
'audios': [
|
||||
{
|
||||
'audio_id': 'xxx',
|
||||
'captions': [
|
||||
{
|
||||
'caption': 'xxx',
|
||||
'cap_id': 'xxx'
|
||||
}
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
threshold (int): Threshold to drop all words with counts < threshold
|
||||
keep_punctuation (bool): Includes or excludes punctuation.
|
||||
|
||||
Returns:
|
||||
vocab (Vocab): Object with the processed vocabulary
|
||||
"""
|
||||
data = json.load(open(input_json, "r"))["audios"]
|
||||
|
||||
if zh:
|
||||
from nltk.parse.corenlp import CoreNLPParser
|
||||
from zhon.hanzi import punctuation
|
||||
parser = CoreNLPParser(host_address)
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
||||
# Remove all punctuations
|
||||
if not keep_punctuation:
|
||||
caption = re.sub("[{}]".format(punctuation), "", caption)
|
||||
if character_level:
|
||||
tokens = list(caption)
|
||||
else:
|
||||
tokens = list(parser.tokenize(caption))
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
|
||||
else:
|
||||
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
||||
captions = {}
|
||||
for audio_idx in range(len(data)):
|
||||
audio_id = data[audio_idx]["audio_id"]
|
||||
captions[audio_id] = []
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
||||
captions[audio_id].append({
|
||||
"audio_id": audio_id,
|
||||
"id": cap_idx,
|
||||
"caption": caption
|
||||
})
|
||||
tokenizer = PTBTokenizer()
|
||||
captions = tokenizer.tokenize(captions)
|
||||
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
||||
audio_id = data[audio_idx]["audio_id"]
|
||||
for cap_idx in range(len(data[audio_idx]["captions"])):
|
||||
tokens = captions[audio_id][cap_idx]
|
||||
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
|
||||
|
||||
if output_json:
|
||||
json.dump(
|
||||
{ "audios": data }, open(output_json, "w"),
|
||||
indent=4, ensure_ascii=not zh)
|
||||
else:
|
||||
json.dump(
|
||||
{ "audios": data }, open(input_json, "w"),
|
||||
indent=4, ensure_ascii=not zh)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(tokenize_caption)
|
||||
178
audio_to_text/captioning/utils/train_util.py
Normal file
178
audio_to_text/captioning/utils/train_util.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Callable, Dict, Union
|
||||
import yaml
|
||||
import torch
|
||||
from torch.optim.swa_utils import AveragedModel as torch_average_model
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pprint import pformat
|
||||
|
||||
|
||||
def load_dict_from_csv(csv, cols):
|
||||
df = pd.read_csv(csv, sep="\t")
|
||||
output = dict(zip(df[cols[0]], df[cols[1]]))
|
||||
return output
|
||||
|
||||
|
||||
def init_logger(filename, level="INFO"):
|
||||
formatter = logging.Formatter(
|
||||
"[ %(levelname)s : %(asctime)s ] - %(message)s")
|
||||
logger = logging.getLogger(__name__ + "." + filename)
|
||||
logger.setLevel(getattr(logging, level))
|
||||
# Log results to std
|
||||
# stdhandler = logging.StreamHandler(sys.stdout)
|
||||
# stdhandler.setFormatter(formatter)
|
||||
# Dump log to file
|
||||
filehandler = logging.FileHandler(filename)
|
||||
filehandler.setFormatter(formatter)
|
||||
logger.addHandler(filehandler)
|
||||
# logger.addHandler(stdhandler)
|
||||
return logger
|
||||
|
||||
|
||||
def init_obj(module, config, **kwargs):# 'captioning.models.encoder'
|
||||
obj_args = config["args"].copy()
|
||||
obj_args.update(kwargs)
|
||||
return getattr(module, config["type"])(**obj_args)
|
||||
|
||||
|
||||
def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
|
||||
"""pprint_dict
|
||||
|
||||
:param outputfun: function to use, defaults to sys.stdout
|
||||
:param in_dict: dict to print
|
||||
"""
|
||||
if formatter == 'yaml':
|
||||
format_fun = yaml.dump
|
||||
elif formatter == 'pretty':
|
||||
format_fun = pformat
|
||||
for line in format_fun(in_dict).split('\n'):
|
||||
outputfun(line)
|
||||
|
||||
|
||||
def merge_a_into_b(a, b):
|
||||
# merge dict a into dict b. values in a will overwrite b.
|
||||
for k, v in a.items():
|
||||
if isinstance(v, dict) and k in b:
|
||||
assert isinstance(
|
||||
b[k], dict
|
||||
), "Cannot inherit key '{}' from base!".format(k)
|
||||
merge_a_into_b(v, b[k])
|
||||
else:
|
||||
b[k] = v
|
||||
|
||||
|
||||
def load_config(config_file):
|
||||
with open(config_file, "r") as reader:
|
||||
config = yaml.load(reader, Loader=yaml.FullLoader)
|
||||
if "inherit_from" in config:
|
||||
base_config_file = config["inherit_from"]
|
||||
base_config_file = os.path.join(
|
||||
os.path.dirname(config_file), base_config_file
|
||||
)
|
||||
assert not os.path.samefile(config_file, base_config_file), \
|
||||
"inherit from itself"
|
||||
base_config = load_config(base_config_file)
|
||||
del config["inherit_from"]
|
||||
merge_a_into_b(config, base_config)
|
||||
return base_config
|
||||
return config
|
||||
|
||||
|
||||
def parse_config_or_kwargs(config_file, **kwargs):
|
||||
yaml_config = load_config(config_file)
|
||||
# passed kwargs will override yaml config
|
||||
args = dict(yaml_config, **kwargs)
|
||||
return args
|
||||
|
||||
|
||||
def store_yaml(config, config_file):
|
||||
with open(config_file, "w") as con_writer:
|
||||
yaml.dump(config, con_writer, indent=4, default_flow_style=False)
|
||||
|
||||
|
||||
class MetricImprover:
|
||||
|
||||
def __init__(self, mode):
|
||||
assert mode in ("min", "max")
|
||||
self.mode = mode
|
||||
# min: lower -> better; max: higher -> better
|
||||
self.best_value = np.inf if mode == "min" else -np.inf
|
||||
|
||||
def compare(self, x, best_x):
|
||||
return x < best_x if self.mode == "min" else x > best_x
|
||||
|
||||
def __call__(self, x):
|
||||
if self.compare(x, self.best_value):
|
||||
self.best_value = x
|
||||
return True
|
||||
return False
|
||||
|
||||
def state_dict(self):
|
||||
return self.__dict__
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
|
||||
def fix_batchnorm(model: torch.nn.Module):
|
||||
def inner(module):
|
||||
class_name = module.__class__.__name__
|
||||
if class_name.find("BatchNorm") != -1:
|
||||
module.eval()
|
||||
model.apply(inner)
|
||||
|
||||
|
||||
def load_pretrained_model(model: torch.nn.Module,
|
||||
pretrained: Union[str, Dict],
|
||||
output_fn: Callable = sys.stdout.write):
|
||||
if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
|
||||
output_fn(f"pretrained {pretrained} not exist!")
|
||||
return
|
||||
|
||||
if hasattr(model, "load_pretrained"):
|
||||
model.load_pretrained(pretrained)
|
||||
return
|
||||
|
||||
if isinstance(pretrained, dict):
|
||||
state_dict = pretrained
|
||||
else:
|
||||
state_dict = torch.load(pretrained, map_location="cpu")
|
||||
|
||||
if "model" in state_dict:
|
||||
state_dict = state_dict["model"]
|
||||
model_dict = model.state_dict()
|
||||
pretrained_dict = {
|
||||
k: v for k, v in state_dict.items() if (k in model_dict) and (
|
||||
model_dict[k].shape == v.shape)
|
||||
}
|
||||
output_fn(f"Loading pretrained keys {pretrained_dict.keys()}")
|
||||
model_dict.update(pretrained_dict)
|
||||
model.load_state_dict(model_dict, strict=True)
|
||||
|
||||
|
||||
class AveragedModel(torch_average_model):
|
||||
|
||||
def update_parameters(self, model):
|
||||
for p_swa, p_model in zip(self.parameters(), model.parameters()):
|
||||
device = p_swa.device
|
||||
p_model_ = p_model.detach().to(device)
|
||||
if self.n_averaged == 0:
|
||||
p_swa.detach().copy_(p_model_)
|
||||
else:
|
||||
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
|
||||
self.n_averaged.to(device)))
|
||||
|
||||
for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
|
||||
device = b_swa.device
|
||||
b_model_ = b_model.detach().to(device)
|
||||
if self.n_averaged == 0:
|
||||
b_swa.detach().copy_(b_model_)
|
||||
else:
|
||||
b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
|
||||
self.n_averaged.to(device)))
|
||||
self.n_averaged += 1
|
||||
@@ -0,0 +1,67 @@
|
||||
# coding=utf-8
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import gensim
|
||||
from gensim.models import Word2Vec
|
||||
from tqdm import tqdm
|
||||
import fire
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.getcwd())
|
||||
from utils.build_vocab import Vocabulary
|
||||
|
||||
def create_embedding(vocab_file: str,
|
||||
embed_size: int,
|
||||
output: str,
|
||||
caption_file: str = None,
|
||||
pretrained_weights_path: str = None,
|
||||
**word2vec_kwargs):
|
||||
vocabulary = torch.load(vocab_file, map_location="cpu")
|
||||
|
||||
if pretrained_weights_path:
|
||||
model = gensim.models.KeyedVectors.load_word2vec_format(
|
||||
fname=pretrained_weights_path,
|
||||
binary=True,
|
||||
)
|
||||
if model.vector_size != embed_size:
|
||||
assert embed_size < model.vector_size, f"only reduce dimension, cannot add dimesion {model.vector_size} to {embed_size}"
|
||||
from sklearn.decomposition import PCA
|
||||
pca = PCA(n_components=embed_size)
|
||||
model.vectors = pca.fit_transform(model.vectors)
|
||||
else:
|
||||
caption_df = pd.read_json(caption_file)
|
||||
caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])
|
||||
sentences = list(caption_df["tokens"].values)
|
||||
epochs = word2vec_kwargs.get("epochs", 10)
|
||||
if "epochs" in word2vec_kwargs:
|
||||
del word2vec_kwargs["epochs"]
|
||||
model = Word2Vec(size=embed_size, min_count=1, **word2vec_kwargs)
|
||||
model.build_vocab(sentences=sentences)
|
||||
model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
|
||||
|
||||
word_embeddings = np.random.randn(len(vocabulary), embed_size)
|
||||
|
||||
if isinstance(model, gensim.models.word2vec.Word2Vec):
|
||||
model = model.wv
|
||||
with tqdm(total=len(vocabulary), ascii=True) as pbar:
|
||||
for word, idx in vocabulary.word2idx.items():
|
||||
try:
|
||||
word_embeddings[idx] = model.get_vector(word)
|
||||
except KeyError:
|
||||
print(f"word {word} not found in word2vec model, it is random initialized!")
|
||||
pbar.update()
|
||||
|
||||
np.save(output, word_embeddings)
|
||||
|
||||
print("Finish writing word2vec embeddings to " + output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(create_embedding)
|
||||
|
||||
|
||||
|
||||
102
audio_to_text/inference_waveform.py
Normal file
102
audio_to_text/inference_waveform.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import sys
|
||||
import os
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import audio_to_text.captioning.models
|
||||
import audio_to_text.captioning.models.encoder
|
||||
import audio_to_text.captioning.models.decoder
|
||||
import audio_to_text.captioning.utils.train_util as train_util
|
||||
|
||||
|
||||
def load_model(config, checkpoint):
|
||||
ckpt = torch.load(checkpoint, "cpu")
|
||||
encoder_cfg = config["model"]["encoder"]
|
||||
encoder = train_util.init_obj(
|
||||
audio_to_text.captioning.models.encoder,
|
||||
encoder_cfg
|
||||
)
|
||||
if "pretrained" in encoder_cfg:
|
||||
pretrained = encoder_cfg["pretrained"]
|
||||
train_util.load_pretrained_model(encoder,
|
||||
pretrained,
|
||||
sys.stdout.write)
|
||||
decoder_cfg = config["model"]["decoder"]
|
||||
if "vocab_size" not in decoder_cfg["args"]:
|
||||
decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
|
||||
decoder = train_util.init_obj(
|
||||
audio_to_text.captioning.models.decoder,
|
||||
decoder_cfg
|
||||
)
|
||||
if "word_embedding" in decoder_cfg:
|
||||
decoder.load_word_embedding(**decoder_cfg["word_embedding"])
|
||||
if "pretrained" in decoder_cfg:
|
||||
pretrained = decoder_cfg["pretrained"]
|
||||
train_util.load_pretrained_model(decoder,
|
||||
pretrained,
|
||||
sys.stdout.write)
|
||||
model = train_util.init_obj(audio_to_text.captioning.models, config["model"],
|
||||
encoder=encoder, decoder=decoder)
|
||||
train_util.load_pretrained_model(model, ckpt)
|
||||
model.eval()
|
||||
return {
|
||||
"model": model,
|
||||
"vocabulary": ckpt["vocabulary"]
|
||||
}
|
||||
|
||||
|
||||
def decode_caption(word_ids, vocabulary):
|
||||
candidate = []
|
||||
for word_id in word_ids:
|
||||
word = vocabulary[word_id]
|
||||
if word == "<end>":
|
||||
break
|
||||
elif word == "<start>":
|
||||
continue
|
||||
candidate.append(word)
|
||||
candidate = " ".join(candidate)
|
||||
return candidate
|
||||
|
||||
|
||||
class AudioCapModel(object):
|
||||
def __init__(self,weight_dir,device='cuda'):
|
||||
config = os.path.join(weight_dir,'config.yaml')
|
||||
self.config = train_util.parse_config_or_kwargs(config)
|
||||
checkpoint = os.path.join(weight_dir,'swa.pth')
|
||||
resumed = load_model(self.config, checkpoint)
|
||||
model = resumed["model"]
|
||||
self.vocabulary = resumed["vocabulary"]
|
||||
self.model = model.to(device)
|
||||
self.device = device
|
||||
|
||||
def caption(self,audio_list):
|
||||
if isinstance(audio_list,np.ndarray):
|
||||
audio_list = [audio_list]
|
||||
elif isinstance(audio_list,str):
|
||||
audio_list = [librosa.load(audio_list,sr=32000)[0]]
|
||||
|
||||
captions = []
|
||||
for wav in audio_list:
|
||||
inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
|
||||
wav_len = torch.LongTensor([len(wav)])
|
||||
input_dict = {
|
||||
"mode": "inference",
|
||||
"wav": inputwav,
|
||||
"wav_len": wav_len,
|
||||
"specaug": False,
|
||||
"sample_method": "beam",
|
||||
}
|
||||
print(input_dict)
|
||||
out_dict = self.model(input_dict)
|
||||
caption_batch = [decode_caption(seq, self.vocabulary) for seq in \
|
||||
out_dict["seq"].cpu().numpy()]
|
||||
captions.extend(caption_batch)
|
||||
return captions
|
||||
|
||||
|
||||
|
||||
def __call__(self, audio_list):
|
||||
return self.caption(audio_list)
|
||||
|
||||
|
||||
|
||||
@@ -26,4 +26,9 @@ wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/Gene
|
||||
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/spk_map.json
|
||||
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/train_f0s_mean_std.npy
|
||||
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/word_set.json
|
||||
|
||||
wget -P text_to_speech/checkpoints/hifi_lj -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/text_to_speech/checkpoints/hifi_lj/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/text_to_speech/checkpoints/hifi_lj/model_ckpt_steps_2076000.ckpt
|
||||
wget -P text_to_speech/checkpoints/ljspeech/ps_adv_baseline -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/text_to_speech/checkpoints/ljspeech/ps_adv_baseline/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160000.ckpt https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/checkpoints/ljspeech/ps_adv_baseline/model_ckpt_steps_160001.ckpt
|
||||
# Audio to text
|
||||
wget -P audio_to_text/audiocaps_cntrstv_cnn14rnn_trm -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth
|
||||
wget -P audio_to_text/clotho_cntrstv_cnn14rnn_trm -i https://huggingface.co/AIGC-Audio/AudioGPT/blob/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth
|
||||
wget -P audio_to_text/pretrained_feature_extractors https://huggingface.co/AIGC-Audio/AudioGPT/resolve/main/audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth
|
||||
Reference in New Issue
Block a user