update_tsd

yangdongchao
2023-04-05 16:07:59 +08:00
parent 1a7f19b244
commit a70460047f
4 changed files with 168 additions and 14 deletions

View File

@@ -52,6 +52,14 @@ import librosa
from audio_infer.utils import config as detection_config
from audio_infer.pytorch.models import PVT
from src.models import BinauralNetwork
from sound_extraction.model.LASSNet import LASSNet
from sound_extraction.utils.stft import STFT
from sound_extraction.utils.wav_io import load_wav, save_wav
from target_sound_detection.src import models as tsd_models
from target_sound_detection.src.models import event_labels
from target_sound_detection.src.utils import median_filter, decode_with_timestamps
import clip
import numpy as np
import uuid
AUDIO_CHATGPT_PREFIX = """Audio ChatGPT
Audio ChatGPT cannot directly read audio files, but it has a list of tools to finish different audio synthesis tasks. Each audio will have a file name formed as "audio/xxx.wav". When talking about audio files, Audio ChatGPT is very strict about the file name and will never fabricate nonexistent files.
@@ -507,7 +515,7 @@ class SoundDetection:
        self.fmin = 50
        self.fmax = 14000
        self.model_type = 'PVT'
        self.checkpoint_path = './audio_detection/audio_infer/useful_ckpts/audio_detection.pth'
        self.classes_num = detection_config.classes_num
        self.labels = detection_config.labels
        self.frames_per_second = self.sample_rate // self.hop_size
@@ -556,19 +564,53 @@ class SoundDetection:
        axs[1].set_xlabel('Seconds')
        axs[1].xaxis.set_ticks_position('bottom')
        plt.tight_layout()
        image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
        plt.savefig(image_filename)
        return image_filename
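Note that this commit redirects plots into image/ and the new SoundExtraction class writes into audio/; neither plt.savefig nor save_wav creates missing directories. A minimal guard, assuming the app runs from the repository root (this snippet is an illustration, not part of the commit):

import os
# create the output directories the tools write into, if they are missing
for d in ('image', 'audio'):
    os.makedirs(d, exist_ok=True)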
class SoundExtraction:
    def __init__(self, device):
        self.device = device
        self.model_file = './sound_extraction/useful_ckpts/LASSNet.pt'
        self.stft = STFT()
        import torch.nn as nn
        self.model = nn.DataParallel(LASSNet(device)).to(device)
        checkpoint = torch.load(self.model_file)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

    def inference(self, inputs):
        # inputs is a single comma-separated string: "<mixture audio path>,<text query>"
        audio_path, text = inputs.split(",", 1)
        text = text.strip()
        waveform = load_wav(audio_path)
        waveform = torch.tensor(waveform).transpose(1, 0)
        mixed_mag, mixed_phase = self.stft.transform(waveform)
        text_query = ['[CLS] ' + text]
        mixed_mag = mixed_mag.transpose(2, 1).unsqueeze(0).to(self.device)
        # LASSNet predicts a magnitude mask conditioned on the text query
        est_mask = self.model(mixed_mag, text_query)
        est_mag = (est_mask * mixed_mag).squeeze(1)
        est_mag = est_mag.permute(0, 2, 1)
        # reuse the mixture phase to invert the estimated magnitude back to a waveform
        est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
        est_wav = est_wav.squeeze(0).squeeze(0).numpy()
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
        print('audio_filename ', audio_filename)
        save_wav(est_wav, audio_filename)
        return audio_filename
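For reference, a minimal usage sketch of the new class, assuming the LASSNet checkpoint has been downloaded (file names here are hypothetical):

# hypothetical driver code, not part of the commit
extractor = SoundExtraction(device="cuda:0")
# one comma-separated string: "<mixture audio path>,<text query>"
out_path = extractor.inference("audio/mixture.wav,a dog barking")
print(out_path)  # e.g. audio/1a2b3c4d.wav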
class Binaural:
    def __init__(self, device):
        self.device = device
        self.model_file = './mono2binaural/useful_ckpts/m2b/binaural_network.net'
        self.position_file = ['./mono2binaural/useful_ckpts/m2b/tx_positions.txt',
                              './mono2binaural/useful_ckpts/m2b/tx_positions2.txt',
                              './mono2binaural/useful_ckpts/m2b/tx_positions3.txt',
                              './mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
                              './mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
        self.net = BinauralNetwork(view_dim=7,
                                   warpnet_layers=4,
                                   warpnet_channels=64,
@@ -621,6 +663,103 @@ class Binaural:
print(f"Processed Binaural.run, audio_filename: {audio_filename}") print(f"Processed Binaural.run, audio_filename: {audio_filename}")
return audio_filename return audio_filename
class TargetSoundDetection:
    def __init__(self, device):
        self.device = device
        self.MEL_ARGS = {
            'n_mels': 64,
            'n_fft': 2048,
            'hop_length': int(22050 * 20 / 1000),
            'win_length': int(22050 * 40 / 1000)
        }
        self.EPS = np.spacing(1)
        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
        self.event_labels = event_labels
        self.id_to_event = {i: label for i, label in enumerate(self.event_labels)}
        config = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth', map_location='cpu')
        config_parameters = dict(config)
        config_parameters['tao'] = 0.6
        if 'thres' not in config_parameters.keys():
            config_parameters['thres'] = 0.5
        if 'time_resolution' not in config_parameters.keys():
            config_parameters['time_resolution'] = 125
        model_parameters = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/run_model_7_loss=-0.0724.pt',
                                      map_location=lambda storage, loc: storage)  # load parameters on CPU
        self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
            inputdim=64, outputdim=2, time_resolution=config_parameters['time_resolution'], **config_parameters['model_args'])
        self.model.load_state_dict(model_parameters)
        self.model = self.model.to(self.device).eval()
        # pre-computed CLIP text embeddings and reference mel clips, one entry per event label
        self.re_embeds = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
        self.ref_mel = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/ref_mel.pth')

    def extract_feature(self, fname):
        import soundfile as sf
        y, sr = sf.read(fname, dtype='float32')
        ti = y.shape[0] / sr  # clip duration in seconds
        if y.ndim > 1:
            y = y.mean(1)  # downmix to mono
        y = librosa.resample(y, orig_sr=sr, target_sr=22050)
        lms_feature = np.log(librosa.feature.melspectrogram(y=y, **self.MEL_ARGS) + self.EPS).T
        return lms_feature, ti

    def build_clip(self, text):
        text = clip.tokenize(text).to(self.device)  # e.g. ["a dog barking"]
        text_features = self.clip_model.encode_text(text)
        return text_features

    def cal_similarity(self, target, retrievals):
        # index of the stored embedding with the highest cosine similarity to the target
        ans = []
        for name in retrievals.keys():
            tmp = retrievals[name]
            s = torch.cosine_similarity(target.squeeze(), tmp.squeeze(), dim=0)
            ans.append(s.item())
        return ans.index(max(ans))

    def inference(self, inputs):
        # inputs is a single comma-separated string: "<text description>,<audio path>";
        # a langchain Tool passes exactly one string, so the two fields are parsed here
        text, audio_path = inputs.rsplit(",", 1)
        audio_path = audio_path.strip()
        target_emb = self.build_clip(text)  # CLIP embedding of the query text
        idx = self.cal_similarity(target_emb, self.re_embeds)
        target_event = self.id_to_event[idx]
        embedding = self.ref_mel[target_event]
        embedding = torch.from_numpy(embedding)
        embedding = embedding.unsqueeze(0).to(self.device).float()
        feature, ti = self.extract_feature(audio_path)
        feature = torch.from_numpy(feature)
        feature = feature.unsqueeze(0).to(self.device).float()
        decision, decision_up, logit = self.model(feature, embedding)
        pred = decision_up.detach().cpu().numpy()
        pred = pred[:, :, 0]
        frame_num = decision_up.shape[1]
        time_ratio = ti / frame_num  # seconds per output frame
        filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
        time_predictions = []
        for index_k in range(filtered_pred.shape[0]):
            decoded_pred = []
            decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k, :])
            if len(decoded_pred_) == 0:  # no positive frames: emit a placeholder (0, 0) segment
                decoded_pred_.append((target_event, 0, 0))
            decoded_pred.append(decoded_pred_)
            for num_batch in range(len(decoded_pred)):  # batch size is 1 at test time
                cur_pred = pred[num_batch]  # per-frame probabilities, kept for later visualization
                label_prediction = decoded_pred[num_batch]  # decoded (event, onset, offset) triples
                for event_label, onset, offset in label_prediction:
                    time_predictions.append({
                        'onset': onset * time_ratio,
                        'offset': offset * time_ratio})
        ans = ''
        for i, item in enumerate(time_predictions):
            ans += 'segment' + str(i + 1) + ' start_time: ' + str(item['onset']) + ' end_time: ' + str(item['offset']) + '\t'
        return ans
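The decoder returns frame indices, which inference converts to seconds via time_ratio = duration / frame_num. A sketch of the expected call and output, with hypothetical file names, assuming the tsd checkpoints are in place:

# hypothetical driver code, not part of the commit
tsd = TargetSoundDetection(device="cuda:1")
result = tsd.inference("a dog barking,audio/street.wav")
# result is a tab-separated string such as:
# "segment1 start_time: 1.2 end_time: 3.4\tsegment2 start_time: 7.0 end_time: 8.5\t"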
class ConversationBot:
    def __init__(self):
        print("Initializing AudioChatGPT")
@@ -636,6 +775,8 @@ class ConversationBot:
        self.tts_ood = TTS_OOD(device="cuda:0")
        self.detection = SoundDetection(device="cuda:0")
        self.binaural = Binaural(device="cuda:1")
        self.extraction = SoundExtraction(device="cuda:0")
        self.TSD = TargetSoundDetection(device="cuda:1")
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
        self.tools = [
            Tool(name="Generate Image From User Input Text", func=self.t2i.inference,
@@ -672,7 +813,13 @@ class ConversationBot:
"The input to this tool should be a string, representing the audio_path. "), "The input to this tool should be a string, representing the audio_path. "),
Tool(name="Sythesize binaural audio from a mono audio input", func=self.binaural.inference, Tool(name="Sythesize binaural audio from a mono audio input", func=self.binaural.inference,
description="useful for when you want to transfer your mono audio into binaural audio, receives audio_path as input. " description="useful for when you want to transfer your mono audio into binaural audio, receives audio_path as input. "
"The input to this tool should be a string, representing the audio_path. ")] "The input to this tool should be a string, representing the audio_path. "),
Tool(name="Extract sound event from mixture audio based on language description", func=self.extraction.inference,
description="useful for when you extract target sound from a mixture audio, you can describe the taregt sound by text, receives audio_path and text as input. "
"The input to this tool should be a comma seperated string of two, representing mixture audio path and input text."),
Tool(name="Detect the sound event from the audio based on your descriptions", func=self.TSD.inference,
description="useful for when you want to know the when happens the target sound event in th audio. You can use language descriptions to instruct the model. receives text description and audio_path as input. "
"The input to this tool should be a string, representing the answer. "),]
        self.agent = initialize_agent(
            self.tools,
            self.llm,

View File

@@ -1,5 +0,0 @@
CUDA_VISIBLE_DEVICES=0 python3 pytorch/inference.py sound_event_detection \
--model_type=PVT \
--checkpoint_path=/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/audio_chatgpt/ft_local/audio_infer/220000_iterations.pth \
--audio_path="/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/audio_chatgpt/ft_local/audio_infer/YDlWd7Wmdi1E.wav" \
--cuda

View File

@@ -27,3 +27,15 @@ wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/Gene
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/train_f0s_mean_std.npy
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/word_set.json
# run from the repository root; subshells keep the working directory fixed between downloads
(cd audio_detection/audio_infer/useful_ckpts && wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/audio_detection.pth)
(cd mono2binaural/useful_ckpts && wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/m2b.tar.gz && tar -zxvf m2b.tar.gz && rm m2b.tar.gz)
(cd audio_detection/target_sound_detection/useful_ckpts && wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/tsd.tar.gz && tar -zxvf tsd.tar.gz && rm tsd.tar.gz)
(cd sound_extraction/useful_ckpts && wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/LASSNet.pt)
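After the downloads, a quick sanity check that every checkpoint sits where the Python constructors above expect it (this snippet is an illustration, not part of the commit):

import os
# checkpoint paths hard-coded in the class constructors
expected = [
    './audio_detection/audio_infer/useful_ckpts/audio_detection.pth',
    './mono2binaural/useful_ckpts/m2b/binaural_network.net',
    './audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth',
    './sound_extraction/useful_ckpts/LASSNet.pt',
]
missing = [p for p in expected if not os.path.exists(p)]
print('missing checkpoints:', missing or 'none')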