mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-16 20:07:58 +01:00
update_tsd
This commit is contained in:
165
audio-chatgpt.py
165
audio-chatgpt.py
@@ -52,6 +52,14 @@ import librosa
|
||||
from audio_infer.utils import config as detection_config
|
||||
from audio_infer.pytorch.models import PVT
|
||||
from src.models import BinauralNetwork
|
||||
from sound_extraction.model.LASSNet import LASSNet
|
||||
from sound_extraction.utils.stft import STFT
|
||||
from sound_extraction.utils.wav_io import load_wav, save_wav
|
||||
from target_sound_detection.src import models as tsd_models
|
||||
from target_sound_detection.src.models import event_labels
|
||||
from target_sound_detection.src.utils import median_filter, decode_with_timestamps
|
||||
import clip
|
||||
import numpy as np
|
||||
import uuid
|
||||
AUDIO_CHATGPT_PREFIX = """Audio ChatGPT
|
||||
AUdio ChatGPT can not directly read audios, but it has a list of tools to finish different audio synthesis tasks. Each audio will have a file name formed as "audio/xxx.wav". When talking about audios, Audio ChatGPT is very strict to the file name and will never fabricate nonexistent files.
|
||||
@@ -507,7 +515,7 @@ class SoundDetection:
|
||||
self.fmin = 50
|
||||
self.fmax = 14000
|
||||
self.model_type = 'PVT'
|
||||
self.checkpoint_path = './audio_detection/audio_infer/useful_ckpts/220000_iterations.pth'
|
||||
self.checkpoint_path = './audio_detection/audio_infer/useful_ckpts/audio_detection.pth'
|
||||
self.classes_num = detection_config.classes_num
|
||||
self.labels = detection_config.labels
|
||||
self.frames_per_second = self.sample_rate // self.hop_size
|
||||
@@ -556,19 +564,53 @@ class SoundDetection:
|
||||
axs[1].set_xlabel('Seconds')
|
||||
axs[1].xaxis.set_ticks_position('bottom')
|
||||
plt.tight_layout()
|
||||
image_filename = os.path.join(str(uuid.uuid4())[0:8] + ".png")
|
||||
image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
|
||||
plt.savefig(image_filename)
|
||||
return image_filename
|
||||
|
||||
class SoundExtraction:
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.model_file = './sound_extraction/useful_ckpts/LASSNet.pt'
|
||||
self.stft = STFT()
|
||||
import torch.nn as nn
|
||||
self.model = nn.DataParallel(LASSNet(device)).to(device)
|
||||
checkpoint = torch.load(self.model_file)
|
||||
self.model.load_state_dict(checkpoint['model'])
|
||||
self.model.eval()
|
||||
|
||||
def inference(self, inputs):
|
||||
#key = ['ref_audio', 'text']
|
||||
val = inputs.split(",")
|
||||
audio_path = val[0] # audio_path, text
|
||||
text = val[1]
|
||||
waveform = load_wav(audio_path)
|
||||
waveform = torch.tensor(waveform).transpose(1,0)
|
||||
mixed_mag, mixed_phase = self.stft.transform(waveform)
|
||||
text_query = ['[CLS] ' + text]
|
||||
mixed_mag = mixed_mag.transpose(2,1).unsqueeze(0).to(self.device)
|
||||
est_mask = self.model(mixed_mag, text_query)
|
||||
est_mag = est_mask * mixed_mag
|
||||
est_mag = est_mag.squeeze(1)
|
||||
est_mag = est_mag.permute(0, 2, 1)
|
||||
est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
|
||||
est_wav = est_wav.squeeze(0).squeeze(0).numpy()
|
||||
#est_path = f'output/est{i}.wav'
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
print('audio_filename ', audio_filename)
|
||||
save_wav(est_wav, audio_filename)
|
||||
return audio_filename
|
||||
|
||||
|
||||
class Binaural:
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.model_file = './mono2binaural/useful_ckpts/binaural_network.net'
|
||||
self.position_file = ['./mono2binaural/useful_ckpts/tx_positions.txt',
|
||||
'./mono2binaural/useful_ckpts/tx_positions2.txt',
|
||||
'./mono2binaural/useful_ckpts/tx_positions3.txt',
|
||||
'./mono2binaural/useful_ckpts/tx_positions4.txt',
|
||||
'./mono2binaural/useful_ckpts/tx_positions5.txt']
|
||||
self.model_file = './mono2binaural/useful_ckpts/m2b/binaural_network.net'
|
||||
self.position_file = ['./mono2binaural/useful_ckpts/m2b/tx_positions.txt',
|
||||
'./mono2binaural/useful_ckpts/m2b/tx_positions2.txt',
|
||||
'./mono2binaural/useful_ckpts/m2b/tx_positions3.txt',
|
||||
'./mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
|
||||
'./mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
|
||||
self.net = BinauralNetwork(view_dim=7,
|
||||
warpnet_layers=4,
|
||||
warpnet_channels=64,
|
||||
@@ -621,6 +663,103 @@ class Binaural:
|
||||
print(f"Processed Binaural.run, audio_filename: {audio_filename}")
|
||||
return audio_filename
|
||||
|
||||
class TargetSoundDetection:
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.MEL_ARGS = {
|
||||
'n_mels': 64,
|
||||
'n_fft': 2048,
|
||||
'hop_length': int(22050 * 20 / 1000),
|
||||
'win_length': int(22050 * 40 / 1000)
|
||||
}
|
||||
self.EPS = np.spacing(1)
|
||||
self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
|
||||
self.event_labels = event_labels
|
||||
self.id_to_event = {i : label for i, label in enumerate(self.event_labels)}
|
||||
config = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth', map_location='cpu')
|
||||
config_parameters = dict(config)
|
||||
config_parameters['tao'] = 0.6
|
||||
if 'thres' not in config_parameters.keys():
|
||||
config_parameters['thres'] = 0.5
|
||||
if 'time_resolution' not in config_parameters.keys():
|
||||
config_parameters['time_resolution'] = 125
|
||||
model_parameters = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/run_model_7_loss=-0.0724.pt'
|
||||
, map_location=lambda storage, loc: storage) # load parameter
|
||||
self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
|
||||
inputdim=64, outputdim=2, time_resolution=config_parameters['time_resolution'], **config_parameters['model_args'])
|
||||
self.model.load_state_dict(model_parameters)
|
||||
self.model = self.model.to(self.device).eval()
|
||||
self.re_embeds = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
|
||||
self.ref_mel = torch.load('./audio_detection/target_sound_detection/useful_ckpts/tsd/ref_mel.pth')
|
||||
|
||||
def extract_feature(self, fname):
|
||||
import soundfile as sf
|
||||
y, sr = sf.read(fname, dtype='float32')
|
||||
print('y ', y.shape)
|
||||
ti = y.shape[0]/sr
|
||||
if y.ndim > 1:
|
||||
y = y.mean(1)
|
||||
y = librosa.resample(y, sr, 22050)
|
||||
lms_feature = np.log(librosa.feature.melspectrogram(y, **self.MEL_ARGS) + self.EPS).T
|
||||
return lms_feature,ti
|
||||
|
||||
def build_clip(self, text):
|
||||
text = clip.tokenize(text).to(self.device) # ["a diagram with dog", "a dog", "a cat"]
|
||||
text_features = self.clip_model.encode_text(text)
|
||||
return text_features
|
||||
|
||||
def cal_similarity(self, target, retrievals):
|
||||
ans = []
|
||||
#target =torch.from_numpy(target)
|
||||
for name in retrievals.keys():
|
||||
tmp = retrievals[name]
|
||||
#tmp = torch.from_numpy(tmp)
|
||||
s = torch.cosine_similarity(target.squeeze(), tmp.squeeze(), dim=0)
|
||||
ans.append(s.item())
|
||||
return ans.index(max(ans))
|
||||
|
||||
def inference(self, text, audio_path):
|
||||
target_emb = self.build_clip(text) # torch type
|
||||
idx = self.cal_similarity(target_emb, self.re_embeds)
|
||||
target_event = self.id_to_event[idx]
|
||||
embedding = self.ref_mel[target_event]
|
||||
embedding = torch.from_numpy(embedding)
|
||||
embedding = embedding.unsqueeze(0).to(self.device).float()
|
||||
#print('embedding ', embedding.shape)
|
||||
inputs,ti = self.extract_feature(audio_path)
|
||||
#print('ti ', ti)
|
||||
inputs = torch.from_numpy(inputs)
|
||||
inputs = inputs.unsqueeze(0).to(self.device).float()
|
||||
#print('inputs ', inputs.shape)
|
||||
decision, decision_up, logit = self.model(inputs, embedding)
|
||||
pred = decision_up.detach().cpu().numpy()
|
||||
pred = pred[:,:,0]
|
||||
frame_num = decision_up.shape[1]
|
||||
time_ratio = ti / frame_num
|
||||
filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
|
||||
#print('filtered_pred ', filtered_pred)
|
||||
time_predictions = []
|
||||
for index_k in range(filtered_pred.shape[0]):
|
||||
decoded_pred = []
|
||||
decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k,:])
|
||||
if len(decoded_pred_) == 0: # neg deal
|
||||
decoded_pred_.append((target_event, 0, 0))
|
||||
decoded_pred.append(decoded_pred_)
|
||||
for num_batch in range(len(decoded_pred)): # when we test our model,the batch_size is 1
|
||||
cur_pred = pred[num_batch]
|
||||
# Save each frame output, for later visualization
|
||||
label_prediction = decoded_pred[num_batch] # frame predict
|
||||
# print(label_prediction)
|
||||
for event_label, onset, offset in label_prediction:
|
||||
time_predictions.append({
|
||||
'onset': onset*time_ratio,
|
||||
'offset': offset*time_ratio,})
|
||||
ans = ''
|
||||
for i,item in enumerate(time_predictions):
|
||||
ans = ans + 'segment' + str(i+1) + ' start_time: ' + str(item['onset']) + ' end_time: ' + str(item['offset']) + '\t'
|
||||
#print(ans)
|
||||
return ans
|
||||
|
||||
class ConversationBot:
|
||||
def __init__(self):
|
||||
print("Initializing AudioChatGPT")
|
||||
@@ -636,6 +775,8 @@ class ConversationBot:
|
||||
self.tts_ood = TTS_OOD(device="cuda:0")
|
||||
self.detection = SoundDetection(device="cuda:0")
|
||||
self.binaural = Binaural(device="cuda:1")
|
||||
self.extraction = SoundExtraction(device="cuda:0")
|
||||
self.TSD = TargetSoundDetection(device="cuda:1")
|
||||
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
||||
self.tools = [
|
||||
Tool(name="Generate Image From User Input Text", func=self.t2i.inference,
|
||||
@@ -672,7 +813,13 @@ class ConversationBot:
|
||||
"The input to this tool should be a string, representing the audio_path. "),
|
||||
Tool(name="Sythesize binaural audio from a mono audio input", func=self.binaural.inference,
|
||||
description="useful for when you want to transfer your mono audio into binaural audio, receives audio_path as input. "
|
||||
"The input to this tool should be a string, representing the audio_path. ")]
|
||||
"The input to this tool should be a string, representing the audio_path. "),
|
||||
Tool(name="Extract sound event from mixture audio based on language description", func=self.extraction.inference,
|
||||
description="useful for when you extract target sound from a mixture audio, you can describe the taregt sound by text, receives audio_path and text as input. "
|
||||
"The input to this tool should be a comma seperated string of two, representing mixture audio path and input text."),
|
||||
Tool(name="Detect the sound event from the audio based on your descriptions", func=self.TSD.inference,
|
||||
description="useful for when you want to know the when happens the target sound event in th audio. You can use language descriptions to instruct the model. receives text description and audio_path as input. "
|
||||
"The input to this tool should be a string, representing the answer. "),]
|
||||
self.agent = initialize_agent(
|
||||
self.tools,
|
||||
self.llm,
|
||||
|
||||
Binary file not shown.
@@ -1,5 +0,0 @@
|
||||
CUDA_VISIBLE_DEVICES=0 python3 pytorch/inference.py sound_event_detection \
|
||||
--model_type=PVT \
|
||||
--checkpoint_path=/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/audio_chatgpt/ft_local/audio_infer/220000_iterations.pth \
|
||||
--audio_path="/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/audio_chatgpt/ft_local/audio_infer/YDlWd7Wmdi1E.wav" \
|
||||
--cuda
|
||||
12
download.sh
12
download.sh
@@ -27,3 +27,15 @@ wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/Gene
|
||||
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/train_f0s_mean_std.npy
|
||||
wget -P data/binary/training_set https://huggingface.co/spaces/Rongjiehuang/GenerSpeech/resolve/main/data/binary/training_set/word_set.json
|
||||
|
||||
cd audio_detection/audio_infer/useful_ckpts
|
||||
wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/audio_detection.pth
|
||||
cd mono2binaural/useful_ckpts
|
||||
wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/m2b.tar.gz
|
||||
tar -zxvf m2b.tar.gz ./
|
||||
rm m2b.tar.gz
|
||||
cd audio_detection/target_sound_detection/useful_ckpts
|
||||
wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/tsd.tar.gz
|
||||
tar -zxvf tsd.tar.gz ./
|
||||
rm tsd.tar.gz
|
||||
cd sound_extraction/useful_ckpts
|
||||
wget https://huggingface.co/Dongchao/pre_trained_model/resolve/main/LASSNet.pt
|
||||
Reference in New Issue
Block a user