Mirror of https://github.com/AIGC-Audio/AudioGPT.git (synced 2025-12-16 20:07:58 +01:00)

Commit: Merge branch 'main' of github.com:lmzjms/AudioGPT into main
README.md (15 lines changed)

@@ -8,8 +8,6 @@
 
 ## Capabilities
 
-Up-to-date link: https://cdb7b543afd1c8e8.gradio.app
-
 Here we list the capability of AudioGPT at this time. More supported models and tasks are comming soon. For prompt examples, refer to [asset](assets/README.md).
 
 ### Speech
@@ -18,8 +16,8 @@ Here we list the capability of AudioGPT at this time. More supported models and
 | Text-to-Speech | [FastSpeech](), [SyntaSpeech](), [VITS]() | Yes (WIP) |
 | Style Transfer | [GenerSpeech]() | Yes |
 | Speech Recognition | [whisper](), [Conformer]() | Yes |
-| Speech Enhancement | [ConvTasNet]() | WIP |
-| Speech Separation | [TF-GridNet]() | WIP |
+| Speech Enhancement | [ConvTasNet]() | Yes (WIP) |
+| Speech Separation | [TF-GridNet]() | Yes (WIP) |
 | Speech Translation | [Multi-decoder]() | WIP |
 | Mono-to-Binaural | [NeuralWarp]() | Yes |
 
@@ -46,15 +44,6 @@ Here we list the capability of AudioGPT at this time. More supported models and
 |:-------------------------:|:-------------------------------:|:----------:|
 | Talking Head Synthesis | [GeneFace]() | Yes (WIP) |
 
-## Internal Version Updates
-4.6 Support Sound Extraction/Detection\
-4.3 Support huggingface demo space\
-4.1 Support Audio inpainting and clean codes\
-3.27 Support Style Transfer/Talking head Synthesis\
-3.23 Support Text-to-Sing\
-3.21 Support Image-to-Audio\
-3.19 Support Speech Recognition\
-3.17 Support Text-to-Audio
 
 ## Todo
 - [x] clean text to sing/speech code
audio-chatgpt.py (130 lines changed)

@@ -9,9 +9,7 @@ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mono2
 import gradio as gr
 import matplotlib
 import librosa
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
-from diffusers import StableDiffusionPipeline
 from langchain.agents.initialize import initialize_agent
 from langchain.agents.tools import Tool
 from langchain.chains.conversation.memory import ConversationBufferMemory
@@ -22,32 +20,18 @@ import soundfile
 from PIL import Image
 import numpy as np
 from omegaconf import OmegaConf
-from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
 from einops import repeat
 from ldm.util import instantiate_from_config
 from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000
 from vocoder.bigvgan.models import VocoderBigVGAN
 from ldm.models.diffusion.ddim import DDIMSampler
-from wav_evaluation.models.CLAPWrapper import CLAPWrapper
-from audio_to_text.inference_waveform import AudioCapModel
 import whisper
-from inference.svs.ds_e2e import DiffSingerE2EInfer
-from inference.tts.GenerSpeech import GenerSpeechInfer
-from inference.tts.PortaSpeech import TTSInference
 from utils.hparams import set_hparams
 from utils.hparams import hparams as hp
 import scipy.io.wavfile as wavfile
 import librosa
 from audio_infer.utils import config as detection_config
 from audio_infer.pytorch.models import PVT
-from src.models import BinauralNetwork
-from sound_extraction.model.LASSNet import LASSNet
-from sound_extraction.utils.stft import STFT
-from sound_extraction.utils.wav_io import load_wav, save_wav
-from target_sound_detection.src import models as tsd_models
-from target_sound_detection.src.models import event_labels
-from target_sound_detection.src.utils import median_filter, decode_with_timestamps
-from espnet2.bin.svs_inference import SingingGenerate
 import clip
 import numpy as np
 AUDIO_CHATGPT_PREFIX = """AudioGPT
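The model-specific imports removed above reappear in later hunks inside the `__init__` of the class that actually uses them, so merely importing audio-chatgpt.py no longer loads every toolkit up front. A minimal sketch of that deferred-import pattern, with illustrative names (the class and the deferred module are stand-ins, not code from this repo):

```python
# Minimal sketch of the deferred-import pattern applied in this diff.
# The import cost is paid only when the tool is constructed, and a missing
# optional dependency only breaks the one tool that needs it.

class LazyTool:  # stand-in for T2A, TTS, Binaural, TargetSoundDetection, ...
    def __init__(self, device: str = "cpu"):
        import torch  # stands in for e.g. inference.tts.PortaSpeech
        self.device = device
        self.dummy = torch.zeros(1, device=device)

tool = LazyTool()          # heavy import happens here, not at module import time
print(tool.dummy.device)   # -> cpu
```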
@@ -98,41 +82,6 @@ def cut_dialogue_history(history_memory, keep_last_n_words = 500):
     return '\n' + '\n'.join(paragraphs)
 
 
-
-def initialize_model(config, ckpt, device):
-    config = OmegaConf.load(config)
-    model = instantiate_from_config(config.model)
-    model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
-
-    model = model.to(device)
-    model.cond_stage_model.to(model.device)
-    model.cond_stage_model.device = model.device
-    sampler = DDIMSampler(model)
-    return sampler
-
-def initialize_model_inpaint(config, ckpt):
-    config = OmegaConf.load(config)
-    model = instantiate_from_config(config.model)
-    model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
-    print(model.device,device,model.cond_stage_model.device)
-    sampler = DDIMSampler(model)
-    return sampler
-
-def select_best_audio(prompt,wav_list):
-    clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth','useful_ckpts/CLAP/config.yml',use_cuda=torch.cuda.is_available())
-    text_embeddings = clap_model.get_text_embeddings([prompt])
-    score_list = []
-    for data in wav_list:
-        sr,wav = data
-        audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav),sr)], resample=True)
-        score = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False).squeeze().cpu().numpy()
-        score_list.append(score)
-    max_index = np.array(score_list).argmax()
-    print(score_list,max_index)
-    return wav_list[max_index]
-
 def merge_audio(audio_path_1, audio_path_2):
     merged_signal = []
     sr_1, signal_1 = wavfile.read(audio_path_1)
@@ -147,6 +96,9 @@ def merge_audio(audio_path_1, audio_path_2):
 
 class T2I:
     def __init__(self, device):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from diffusers import StableDiffusionPipeline
+        from transformers import pipeline
         print("Initializing T2I to %s" % device)
         self.device = device
         self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
@@ -166,6 +118,7 @@ class T2I:
 
 class ImageCaptioning:
     def __init__(self, device):
+        from transformers import BlipProcessor, BlipForConditionalGeneration
         print("Initializing ImageCaptioning to %s" % device)
         self.device = device
         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -181,9 +134,20 @@ class T2A:
     def __init__(self, device):
         print("Initializing Make-An-Audio to %s" % device)
         self.device = device
-        self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/text_to_audio/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio/useful_ckpts/ta40multi_epoch=000085.ckpt', device=device)
+        self.sampler = self._initialize_model('text_to_audio/Make_An_Audio/configs/text_to_audio/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio/useful_ckpts/ta40multi_epoch=000085.ckpt', device=device)
         self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w',device=device)
 
+    def _initialize_model(self, config, ckpt, device):
+        config = OmegaConf.load(config)
+        model = instantiate_from_config(config.model)
+        model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
+
+        model = model.to(device)
+        model.cond_stage_model.to(model.device)
+        model.cond_stage_model.device = model.device
+        sampler = DDIMSampler(model)
+        return sampler
+
     def txt2audio(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples = 3, W = 624, H = 80):
         SAMPLE_RATE = 16000
         prng = np.random.RandomState(seed)
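After this change the former module-level `initialize_model` lives on `T2A` as `_initialize_model`, but the class is used the same way from the outside. A hedged usage sketch, assuming the Make-An-Audio and BigVGAN checkpoints referenced in `__init__` are present and that `txt2audio` returns one of the `(SAMPLE_RATE, wav)` tuples it appends to `wav_list`:

```python
# Hedged usage sketch: requires audio-chatgpt.py's environment and the
# checkpoint paths used in T2A.__init__ above.
import soundfile
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
t2a = T2A(device)                                     # loads sampler + vocoder once
sr, wav = t2a.txt2audio("a dog barking in the rain")  # best of n_samples candidates
soundfile.write("generated.wav", wav, sr)
```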
@@ -208,9 +172,25 @@ class T2A:
         for idx,spec in enumerate(x_samples_ddim):
             wav = self.vocoder.vocode(spec)
             wav_list.append((SAMPLE_RATE,wav))
-        best_wav = select_best_audio(text, wav_list)
+        best_wav = self.select_best_audio(text, wav_list)
         return best_wav
 
+    def select_best_audio(self, prompt, wav_list):
+        from wav_evaluation.models.CLAPWrapper import CLAPWrapper
+        clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth', 'useful_ckpts/CLAP/config.yml',
+                                 use_cuda=torch.cuda.is_available())
+        text_embeddings = clap_model.get_text_embeddings([prompt])
+        score_list = []
+        for data in wav_list:
+            sr, wav = data
+            audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav), sr)], resample=True)
+            score = clap_model.compute_similarity(audio_embeddings, text_embeddings,
+                                                  use_logit_scale=False).squeeze().cpu().numpy()
+            score_list.append(score)
+        max_index = np.array(score_list).argmax()
+        print(score_list, max_index)
+        return wav_list[max_index]
+
     def inference(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples = 3, W = 624, H = 80):
         melbins,mel_len = 80,624
         with torch.no_grad():
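`select_best_audio` is a best-of-n rerank: several candidate waveforms are generated, the prompt and each candidate are embedded with CLAP, and the candidate whose audio embedding is most similar to the text embedding wins. A generic, self-contained illustration of that selection step, with NumPy cosine similarity and random stand-in embeddings in place of CLAP:

```python
# Generic illustration of best-of-n reranking by embedding similarity.
# The embeddings are random stand-ins so the snippet runs on its own.
import numpy as np

def pick_best(text_emb: np.ndarray, candidate_embs: list) -> int:
    """Return the index of the candidate most similar to the text embedding."""
    def cosine(a, b):
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))
    scores = [cosine(text_emb, c) for c in candidate_embs]
    return int(np.argmax(scores))

rng = np.random.default_rng(0)
text_emb = rng.normal(size=512)                   # stand-in for a CLAP text embedding
cands = [rng.normal(size=512) for _ in range(3)]  # stand-ins for 3 candidate audio embeddings
print(pick_best(text_emb, cands))                 # index of the "best" candidate
```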
@@ -228,8 +208,20 @@ class I2A:
     def __init__(self, device):
         print("Initializing Make-An-Audio-Image to %s" % device)
         self.device = device
-        self.sampler = initialize_model('text_to_audio/Make_An_Audio_img/configs/img_to_audio/img2audio_args.yaml', 'text_to_audio/Make_An_Audio_img/useful_ckpts/ta54_epoch=000216.ckpt', device=device)
+        self.sampler = self._initialize_model('text_to_audio/Make_An_Audio_img/configs/img_to_audio/img2audio_args.yaml', 'text_to_audio/Make_An_Audio_img/useful_ckpts/ta54_epoch=000216.ckpt', device=device)
         self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio_img/vocoder/logs/bigv16k53w',device=device)
 
+    def _initialize_model(self, config, ckpt, device):
+        config = OmegaConf.load(config)
+        model = instantiate_from_config(config.model)
+        model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
+
+        model = model.to(device)
+        model.cond_stage_model.to(model.device)
+        model.cond_stage_model.device = model.device
+        sampler = DDIMSampler(model)
+        return sampler
+
     def img2audio(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H = 80):
         SAMPLE_RATE = 16000
         n_samples = 1 # only support 1 sample
@@ -275,6 +267,7 @@ class I2A:
 
 class TTS:
     def __init__(self, device=None):
+        from inference.tts.PortaSpeech import TTSInference
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         print("Initializing PortaSpeech to %s" % device)
@@ -297,6 +290,7 @@ class TTS:
 
 class T2S:
     def __init__(self, device= None):
+        from inference.svs.ds_e2e import DiffSingerE2EInfer
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         print("Initializing DiffSinger to %s" % device)
@@ -339,6 +333,7 @@ class T2S:
 
 class t2s_VISinger:
     def __init__(self, device=None):
+        from espnet2.bin.svs_inference import SingingGenerate
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         print("Initializing VISingere to %s" % device)
@@ -380,6 +375,7 @@ class t2s_VISinger:
 
 class TTS_OOD:
     def __init__(self, device):
+        from inference.tts.GenerSpeech import GenerSpeechInfer
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         print("Initializing GenerSpeech to %s" % device)
@@ -416,9 +412,20 @@ class Inpaint:
     def __init__(self, device):
         print("Initializing Make-An-Audio-inpaint to %s" % device)
         self.device = device
-        self.sampler = initialize_model_inpaint('text_to_audio/Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
+        self.sampler = self._initialize_model_inpaint('text_to_audio/Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
         self.vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w',device=device)
         self.cmap_transform = matplotlib.cm.viridis
 
+    def _initialize_model_inpaint(self, config, ckpt):
+        config = OmegaConf.load(config)
+        model = instantiate_from_config(config.model)
+        model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        model = model.to(device)
+        print(model.device, device, model.cond_stage_model.device)
+        sampler = DDIMSampler(model)
+        return sampler
+
     def make_batch_sd(self, mel, mask, num_samples=1):
 
         mel = torch.from_numpy(mel)[None,None,...].to(dtype=torch.float32)
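`make_batch_sd` treats the mel spectrogram as a one-channel image: the `[None, None, ...]` indexing prepends batch and channel axes before the tensor is fed to the diffusion model. A small self-contained illustration of that indexing, using the 80 x 624 mel shape seen elsewhere in this diff as an assumed example:

```python
# Illustration of the [None, None, ...] indexing used in make_batch_sd:
# a 2-D mel array becomes a 4-D (batch, channel, n_mels, frames) float tensor.
import numpy as np
import torch

mel = np.zeros((80, 624), dtype=np.float32)  # assumed mel: 80 bins x 624 frames
batch = torch.from_numpy(mel)[None, None, ...].to(dtype=torch.float32)
print(batch.shape)                           # torch.Size([1, 1, 80, 624])
```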
@@ -563,6 +570,7 @@ class ASR:
 
 class A2T:
     def __init__(self, device):
+        from audio_to_text.inference_waveform import AudioCapModel
         print("Initializing Audio-To-Text Model to %s" % device)
         self.device = device
         self.model = AudioCapModel("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm")
@@ -659,10 +667,13 @@ class SoundDetection:
 
 class SoundExtraction:
     def __init__(self, device):
+        from sound_extraction.model.LASSNet import LASSNet
+        from sound_extraction.utils.stft import STFT
+        import torch.nn as nn
         self.device = device
         self.model_file = 'sound_extraction/useful_ckpts/LASSNet.pt'
         self.stft = STFT()
-        import torch.nn as nn
         self.model = nn.DataParallel(LASSNet(device)).to(device)
         checkpoint = torch.load(self.model_file)
         self.model.load_state_dict(checkpoint['model'])
@@ -670,6 +681,7 @@ class SoundExtraction:
     def inference(self, inputs):
         #key = ['ref_audio', 'text']
+        from sound_extraction.utils.wav_io import load_wav, save_wav
         val = inputs.split(",")
         audio_path = val[0] # audio_path, text
         text = val[1]
@@ -693,6 +705,7 @@ class SoundExtraction:
 
 class Binaural:
     def __init__(self, device):
+        from src.models import BinauralNetwork
        self.device = device
         self.model_file = 'mono2binaural/useful_ckpts/m2b/binaural_network.net'
         self.position_file = ['mono2binaural/useful_ckpts/m2b/tx_positions.txt',
@@ -754,6 +767,9 @@ class Binaural:
 
 class TargetSoundDetection:
     def __init__(self, device):
+        from target_sound_detection.src import models as tsd_models
+        from target_sound_detection.src.models import event_labels
+
         self.device = device
         self.MEL_ARGS = {
             'n_mels': 64,
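`MEL_ARGS` configures the mel front end for target sound detection; only `'n_mels': 64` is visible here, the remaining entries fall outside the diff context. A rough, purely illustrative sketch of computing a 64-band log-mel feature with librosa (the FFT and hop sizes below are assumptions, not the project's settings, and the project may not use librosa for this at all):

```python
# Rough illustration only: a 64-band log-mel spectrogram with librosa.
# n_fft/hop_length are placeholders, NOT values taken from MEL_ARGS.
import librosa
import numpy as np

wav = np.random.randn(16000).astype(np.float32)  # 1 s of noise as dummy audio
sr = 16000
mel = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=1024, hop_length=256, n_mels=64)
log_mel = librosa.power_to_db(mel)
print(log_mel.shape)  # (64, n_frames)
```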
@@ -808,6 +824,8 @@ class TargetSoundDetection:
         return ans.index(max(ans))
 
     def inference(self, text, audio_path):
+        from target_sound_detection.src.utils import median_filter, decode_with_timestamps
+
         target_emb = self.build_clip(text) # torch type
         idx = self.cal_similarity(target_emb, self.re_embeds)
         target_event = self.id_to_event[idx]