mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-16 11:57:58 +01:00
Update audio-chatgpt.py
This commit is contained in:
116
audio-chatgpt.py
116
audio-chatgpt.py
@@ -2,6 +2,9 @@ import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'text_to_sing/DiffSinger'))
|
||||
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'text_to_audio/Make_An_Audio'))
|
||||
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'text_to_audio/Make_An_Audio_img'))
|
||||
import gradio as gr
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||
import torch
|
||||
@@ -28,10 +31,10 @@ from ldm.util import instantiate_from_config
|
||||
from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000
|
||||
from pathlib import Path
|
||||
from vocoder.hifigan.modules import VocoderHifigan
|
||||
from Make_An_Audio_img.vocoder.bigvgan.models import VocoderBigVGAN
|
||||
from vocoder.bigvgan.models import VocoderBigVGAN
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from wav_evaluation.models.CLAPWrapper import CLAPWrapper
|
||||
from Make_An_Audio_img.ldm.util import instantiate_from_config as instantiate_from_config_make_an_audio_img
|
||||
from inference.svs.ds_e2e import DiffSingerE2EInfer
|
||||
import whisper
|
||||
|
||||
AUDIO_CHATGPT_PREFIX = """Audio ChatGPT
|
||||
@@ -60,7 +63,7 @@ Thought: Do I need to use a tool? No
|
||||
"""
|
||||
|
||||
AUDIO_CHATGPT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if not exists.
|
||||
You will remember to provide the image file name loyally if it's provided in the last tool observation.
|
||||
You will remember to provide the audio file name loyally if it's provided in the last tool observation.
|
||||
|
||||
Begin!
|
||||
|
||||
@@ -69,8 +72,7 @@ Previous conversation history:
|
||||
New input: {input}
|
||||
Thought: Do I need to use a tool? {agent_scratchpad}"""
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
temp_audio_filename = "audio/c00d9240.wav"
|
||||
#temp_audio_filename = "audio/c00d9240.wav"
|
||||
|
||||
|
||||
def cut_dialogue_history(history_memory, keep_last_n_words = 500):
|
||||
@@ -116,20 +118,9 @@ def initialize_model(config, ckpt, device):
|
||||
sampler = DDIMSampler(model)
|
||||
return sampler
|
||||
|
||||
clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth','useful_ckpts/CLAP/config.yml',use_cuda=torch.cuda.is_available())
|
||||
|
||||
def initialize_model_img(config, ckpt, device):
|
||||
config = OmegaConf.load(config)
|
||||
model = instantiate_from_config_make_an_audio_img(config.model)
|
||||
model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
|
||||
|
||||
model = model.to(device)
|
||||
model.cond_stage_model.to(model.device)
|
||||
model.cond_stage_model.device = model.device
|
||||
sampler = DDIMSampler(model)
|
||||
return sampler
|
||||
|
||||
def select_best_audio(prompt,wav_list):
|
||||
clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth','useful_ckpts/CLAP/config.yml',use_cuda=torch.cuda.is_available())
|
||||
text_embeddings = clap_model.get_text_embeddings([prompt])
|
||||
score_list = []
|
||||
for data in wav_list:
|
||||
@@ -210,6 +201,7 @@ class T2A:
|
||||
self.vocoder = VocoderHifigan('vocoder/logs/hifi_0127',device=device)
|
||||
|
||||
def txt2audio(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples = 3, W = 624, H = 80):
|
||||
SAMPLE_RATE = 16000
|
||||
prng = np.random.RandomState(seed)
|
||||
start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
|
||||
start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
|
||||
@@ -236,7 +228,6 @@ class T2A:
|
||||
return best_wav
|
||||
|
||||
def inference(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples = 3, W = 624, H = 80):
|
||||
global temp_audio_filename
|
||||
melbins,mel_len = 80,624
|
||||
with torch.no_grad():
|
||||
result = self.txt2audio(
|
||||
@@ -245,7 +236,6 @@ class T2A:
|
||||
W = mel_len
|
||||
)
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
temp_audio_filename = audio_filename
|
||||
soundfile.write(audio_filename, result[1], samplerate = 16000)
|
||||
print(f"Processed T2I.run, text: {text}, audio_filename: {audio_filename}")
|
||||
return audio_filename
|
||||
@@ -254,9 +244,10 @@ class I2A:
|
||||
def __init__(self, device):
|
||||
print("Initializing Make-An-Audio-Image to %s" % device)
|
||||
self.device = device
|
||||
self.sampler = initialize_model_img('Make_An_Audio_img/configs/img_to_audio/img2audio_args.yaml', 'Make_An_Audio_img/useful_ckpts/ta54_epoch=000216.ckpt', device=device)
|
||||
self.vocoder = VocoderBigVGAN('Make_An_Audio_img/vocoder/logs/bigv16k53w',device=device)
|
||||
self.sampler = initialize_model('text_to_audio/Make_An_Audio_img/configs/img_to_audio/img2audio_args.yaml', 'text_to_audio/Make_An_Audio_img/useful_ckpts/ta54_epoch=000216.ckpt', device=device)
|
||||
self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio_img/vocoder/logs/bigv16k53w',device=device)
|
||||
def img2audio(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H = 80):
|
||||
SAMPLE_RATE = 16000
|
||||
n_samples = 1 # only support 1 sample
|
||||
prng = np.random.RandomState(seed)
|
||||
start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
|
||||
@@ -286,7 +277,6 @@ class I2A:
|
||||
best_wav = wav_list[0]
|
||||
return best_wav
|
||||
def inference(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H = 80):
|
||||
global temp_audio_filename
|
||||
melbins,mel_len = 80,624
|
||||
with torch.no_grad():
|
||||
result = self.img2audio(
|
||||
@@ -295,16 +285,43 @@ class I2A:
|
||||
W=mel_len
|
||||
)
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
temp_audio_filename = audio_filename
|
||||
soundfile.write(audio_filename, result[1], samplerate = 16000)
|
||||
print(f"Processed I2a.run, image_filename: {image}, audio_filename: {audio_filename}")
|
||||
return audio_filename
|
||||
|
||||
class T2S:
|
||||
def __init__(self, device= None):
|
||||
if device is None:
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print("Initializing DiffSinger to %s" % device)
|
||||
self.device = device
|
||||
exp_name = 'text_to_sing/DiffSinger/checkpoints/0831_opencpop_ds1000'
|
||||
exp_name = 'checkpoints/0831_opencpop_ds1000'
|
||||
config= 'text_to_sing/DiffSinger/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml'
|
||||
from utils.hparams import set_hparams
|
||||
from utils.hparams import hparams as hp
|
||||
set_hparams(config= config,exp_name=exp_name, print_hparams=False)
|
||||
self.hp = hp
|
||||
self.pipe = DiffSingerE2EInfer(self.hp)
|
||||
def inference(self, inputs):
|
||||
key = ['text', 'notes', 'notes_duration']
|
||||
val = inputs.split(",")
|
||||
print(val)
|
||||
inp = {k:v for k,v in zip(key,val)}
|
||||
print(inp)
|
||||
wav = self.pipe.infer_once(inp)
|
||||
wav *= 32767
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
soundfile.write(audio_filename, wav.astype(np.int16), self.hp['audio_sample_rate'])
|
||||
print(f"Processed T2S.run, text: {val[0]}, notes: {val[1]}, notes duration: {val[2]}, audio_filename: {audio_filename}")
|
||||
return audio_filename
|
||||
|
||||
# need to debug
|
||||
class Inpaint:
|
||||
def __init__(self, device):
|
||||
print("Initializing Make-An-Audio-inpaint to %s" % device)
|
||||
self.device = device
|
||||
self.sampler = initialize_model('Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
|
||||
self.sampler = initialize_model('text_to_audio/Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
|
||||
self.vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w',device=device)
|
||||
def make_batch_sd(mel, mask, num_samples=1):
|
||||
|
||||
@@ -410,7 +427,6 @@ class ASR:
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
mel = whisper.log_mel_spectrogram(audio).to(self.device)
|
||||
_, probs = self.model.detect_language(mel)
|
||||
#print(f"Detected language: {max(probs, key=probs.get)}")
|
||||
options = whisper.DecodingOptions()
|
||||
result = whisper.decode(self.model, mel, options)
|
||||
return result.text
|
||||
@@ -419,9 +435,10 @@ class ConversationBot:
|
||||
def __init__(self):
|
||||
print("Initializing AudioChatGPT")
|
||||
self.llm = OpenAI(temperature=0)
|
||||
self.t2i = T2I(device="cuda:1")
|
||||
self.t2i = T2I(device="cuda:0")
|
||||
self.i2t = ImageCaptioning(device="cuda:1")
|
||||
self.t2a = T2A(device="cuda:0")
|
||||
self.t2s = T2S(device="cuda:2")
|
||||
self.i2a = I2A(device="cuda:1")
|
||||
self.asr = ASR(device="cuda:1")
|
||||
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
||||
@@ -435,13 +452,16 @@ class ConversationBot:
|
||||
Tool(name="Generate Audio From User Input Text", func=self.t2a.inference,
|
||||
description="useful for when you want to generate an audio from a user input text and it saved it to a file."
|
||||
"The input to this tool should be a string, representing the text used to generate audio."),
|
||||
Tool(name="Generate singing voice From User Input Text", func=self.t2s.inference,
|
||||
description="useful for when you want to generate a piece of singing voice from its description."
|
||||
"The input to this tool should be a comma seperated string of three, representing the text sequence and its corresponding note and duration sequence."),
|
||||
Tool(name="Generate Audio From The Image", func=self.i2a.inference,
|
||||
description="useful for when you want to generate an audio based on an image."
|
||||
"The input to this tool should be a string, representing the image_path. "),
|
||||
Tool(name="Get Audio Transcription", func=self.asr.inference,
|
||||
description="useful for when you want to know the text content corresponding to this audio, receives audio_path as input."
|
||||
"The input to this tool should be a string, representing the audio_path.")
|
||||
]
|
||||
]
|
||||
self.agent = initialize_agent(
|
||||
self.tools,
|
||||
self.llm,
|
||||
@@ -457,31 +477,20 @@ class ConversationBot:
|
||||
print("======>Previous memory:\n %s" % self.agent.memory)
|
||||
self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
|
||||
res = self.agent({"input": text})
|
||||
tool = res['intermediate_steps'][0][0].tool
|
||||
if tool == "Generate Image From User Input Text":
|
||||
print("======>Current memory:\n %s" % self.agent.memory)
|
||||
response = re.sub('(image/\S*png)', lambda m: f'})*{m.group(0)}*', res['output'])
|
||||
state = state + [(text, response)]
|
||||
print("Outputs:", state)
|
||||
return state, state, None
|
||||
print("======>Current memory:\n %s" % self.agent.memory)
|
||||
audio_filename = res['intermediate_steps'][0][1]
|
||||
response = re.sub('(image/\S*png)', lambda m: f'})*{m.group(0)}*', res['output'])
|
||||
#response = res['output'] + f"<audio src=audio_filename controls=controls></audio>"
|
||||
state = state + [(text, response)]
|
||||
print("Outputs:", state)
|
||||
return state, state, temp_audio_filename
|
||||
|
||||
def run_audio(self, audio, state, txt):
|
||||
#print(audio.type)
|
||||
print("===============Running run_audio =============")
|
||||
print("Inputs:", audio, state)
|
||||
print("======>Previous memory:\n %s" % self.agent.memory)
|
||||
audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
|
||||
print("======>Auto Resize Audio...")
|
||||
audio_load = whisper.load_audio(audio.name)
|
||||
soundfile.write(audio_filename, audio_load, samplerate = 16000)
|
||||
global temp_audio_filename
|
||||
temp_audio_filename = audio_filename
|
||||
description = self.asr.inference(audio_filename)
|
||||
Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
|
||||
"rather than directly imagine from my description. If you understand, say \"Received\". \n".format(audio_filename, description)
|
||||
AI_prompt = "Received. "
|
||||
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
||||
state = state + [(f"*{audio_filename}*", AI_prompt)]
|
||||
print("Outputs:", state)
|
||||
return state, state, txt + ' ' + audio_filename + ' ', temp_audio_filename
|
||||
return state, state, audio_filename
|
||||
|
||||
def run_image_or_audio(self, file, state, txt):
|
||||
file_type = file.name[-3:]
|
||||
@@ -493,16 +502,15 @@ class ConversationBot:
|
||||
print("======>Auto Resize Audio...")
|
||||
audio_load = whisper.load_audio(file.name)
|
||||
soundfile.write(audio_filename, audio_load, samplerate = 16000)
|
||||
global temp_audio_filename
|
||||
temp_audio_filename = audio_filename
|
||||
description = self.asr.inference(audio_filename)
|
||||
Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
|
||||
"rather than directly imagine from my description. If you understand, say \"Received\". \n".format(audio_filename, description)
|
||||
AI_prompt = "Received. "
|
||||
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
||||
#state = state + [(f"<audio src=audio_filename controls=controls></audio>*{audio_filename}*", AI_prompt)]
|
||||
state = state + [(f"*{audio_filename}*", AI_prompt)]
|
||||
print("Outputs:", state)
|
||||
return state, state, txt + ' ' + audio_filename + ' ', temp_audio_filename
|
||||
return state, state, txt + ' ' + audio_filename + ' ', audio_filename
|
||||
else:
|
||||
print("===============Running run_image =============")
|
||||
print("Inputs:", file, state)
|
||||
@@ -525,7 +533,7 @@ class ConversationBot:
|
||||
print("======>Current memory:\n %s" % self.agent.memory)
|
||||
state = state + [(f"*{image_filename}*", AI_prompt)]
|
||||
print("Outputs:", state)
|
||||
return state, state, txt + ' ' + image_filename + ' ', temp_audio_filename
|
||||
return state, state, txt + ' ' + image_filename + ' ', None
|
||||
|
||||
if __name__ == '__main__':
|
||||
bot = ConversationBot()
|
||||
@@ -545,10 +553,8 @@ if __name__ == '__main__':
|
||||
outaudio = gr.Audio()
|
||||
txt.submit(bot.run_text, [txt, state], [chatbot, state, outaudio])
|
||||
txt.submit(lambda: "", None, txt)
|
||||
#btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
|
||||
btn.upload(bot.run_image_or_audio, [btn, state, txt], [chatbot, state, txt, outaudio])
|
||||
clear.click(bot.memory.clear)
|
||||
clear.click(lambda: [], None, chatbot)
|
||||
clear.click(lambda: [], None, state)
|
||||
#clear.click(lambda: [], None, outaudio)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7862, share=True)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
||||
|
||||
Reference in New Issue
Block a user