Mirror of https://github.com/AIGC-Audio/AudioGPT.git (synced 2025-12-21 14:19:39 +01:00)

Commit: Update audio-chatgpt.py

Changed file: audio-chatgpt.py | 269
@@ -25,10 +25,13 @@ import einops
 from pytorch_lightning import seed_everything
 import random
 from ldm.util import instantiate_from_config
+from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000
 from pathlib import Path
 from vocoder.hifigan.modules import VocoderHifigan
+from Make_An_Audio_img.vocoder.bigvgan.models import VocoderBigVGAN
 from ldm.models.diffusion.ddim import DDIMSampler
 from wav_evaluation.models.CLAPWrapper import CLAPWrapper
+from Make_An_Audio_img.ldm.util import instantiate_from_config as instantiate_from_config_make_an_audio_img
 import whisper
 
 AUDIO_CHATGPT_PREFIX = """Audio ChatGPT
@@ -68,7 +71,7 @@ Thought: Do I need to use a tool? {agent_scratchpad}"""
 
 SAMPLE_RATE = 16000
 temp_audio_filename = "audio/c00d9240.wav"
-# model = whisper.load_model("base")
+
 
 def cut_dialogue_history(history_memory, keep_last_n_words = 500):
     tokens = history_memory.split()
@@ -111,11 +114,21 @@ def initialize_model(config, ckpt, device):
     model.cond_stage_model.to(model.device)
     model.cond_stage_model.device = model.device
     sampler = DDIMSampler(model)
 
     return sampler
 
 clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth','useful_ckpts/CLAP/config.yml',use_cuda=torch.cuda.is_available())
 
+def initialize_model_img(config, ckpt, device):
+    config = OmegaConf.load(config)
+    model = instantiate_from_config_make_an_audio_img(config.model)
+    model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
+
+    model = model.to(device)
+    model.cond_stage_model.to(model.device)
+    model.cond_stage_model.device = model.device
+    sampler = DDIMSampler(model)
+    return sampler
+
 def select_best_audio(prompt,wav_list):
     text_embeddings = clap_model.get_text_embeddings([prompt])
     score_list = []
@@ -176,6 +189,18 @@ class T2I:
         print(f"Processed T2I.run, text: {text}, image_filename: {image_filename}")
         return image_filename
 
+class ImageCaptioning:
+    def __init__(self, device):
+        print("Initializing ImageCaptioning to %s" % device)
+        self.device = device
+        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(self.device)
+
+    def inference(self, image_path):
+        inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device)
+        out = self.model.generate(**inputs)
+        captions = self.processor.decode(out[0], skip_special_tokens=True)
+        return captions
+
 class T2A:
     def __init__(self, device):
@@ -225,6 +250,156 @@ class T2A:
         print(f"Processed T2I.run, text: {text}, audio_filename: {audio_filename}")
         return audio_filename
 
+class I2A:
+    def __init__(self, device):
+        print("Initializing Make-An-Audio-Image to %s" % device)
+        self.device = device
+        self.sampler = initialize_model_img('Make_An_Audio_img/configs/img_to_audio/img2audio_args.yaml', 'Make_An_Audio_img/useful_ckpts/ta54_epoch=000216.ckpt', device=device)
+        self.vocoder = VocoderBigVGAN('Make_An_Audio_img/vocoder/logs/bigv16k53w',device=device)
+    def img2audio(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H = 80):
+        n_samples = 1 # only support 1 sample
+        prng = np.random.RandomState(seed)
+        start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
+        start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
+        uc = self.sampler.model.get_learned_conditioning(n_samples * [""])
+        #image = Image.fromarray(image)
+        image = Image.open(image)
+        image = self.sampler.model.cond_stage_model.preprocess(image).unsqueeze(0)
+        image_embedding = self.sampler.model.cond_stage_model.forward_img(image)
+        c = image_embedding.repeat(n_samples, 1, 1) # shape: [1,77,1280], i.e. still per-word embeddings, not yet collapsed into a sentence embedding
+        shape = [self.sampler.model.first_stage_model.embed_dim, H//8, W//8]  # (z_dim, 80//2^x, 848//2^x)
+        samples_ddim, _ = self.sampler.sample(S=ddim_steps,
+                                              conditioning=c,
+                                              batch_size=n_samples,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=uc,
+                                              x_T=start_code)
+
+        x_samples_ddim = self.sampler.model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) # [0, 1]
+        wav_list = []
+        for idx,spec in enumerate(x_samples_ddim):
+            wav = self.vocoder.vocode(spec)
+            wav_list.append((SAMPLE_RATE,wav))
+        best_wav = wav_list[0]
+        return best_wav
+    def inference(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H = 80):
+        global temp_audio_filename
+        melbins,mel_len = 80,624
+        with torch.no_grad():
+            result = self.img2audio(
+                image=image,
+                H=melbins,
+                W=mel_len
+            )
+        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        temp_audio_filename = audio_filename
+        soundfile.write(audio_filename, result[1], samplerate = 16000)
+        print(f"Processed I2a.run, image_filename: {image}, audio_filename: {audio_filename}")
+        return audio_filename
+# need to debug
+class Inpaint:
+    def __init__(self, device):
+        print("Initializing Make-An-Audio-inpaint to %s" % device)
+        self.device = device
+        self.sampler = initialize_model('Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
+        self.vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w',device=device)
+    def make_batch_sd(mel, mask, num_samples=1):
+
+        mel = torch.from_numpy(mel)[None,None,...].to(dtype=torch.float32)
+        mask = torch.from_numpy(mask)[None,None,...].to(dtype=torch.float32)
+        masked_mel = (1 - mask) * mel
+
+        mel = mel * 2 - 1
+        mask = mask * 2 - 1
+        masked_mel = masked_mel * 2 -1
+
+        batch = {
+            "mel": repeat(mel.to(device=self.device), "1 ... -> n ...", n=num_samples),
+            "mask": repeat(mask.to(device=self.device), "1 ... -> n ...", n=num_samples),
+            "masked_mel": repeat(masked_mel.to(device=self.device), "1 ... -> n ...", n=num_samples),
+        }
+        return batch
+    def gen_mel(input_audio):
+        sr,ori_wav = input_audio
+        print(sr,ori_wav.shape,ori_wav)
+
+        ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0 # order='C' just means C-contiguous storage; nothing to worry about
+        if len(ori_wav.shape)==2: # stereo
+            ori_wav = librosa.to_mono(ori_wav.T) # gradio load wav shape could be (wav_len,2) but librosa expects (2,wav_len)
+        print(sr,ori_wav.shape,ori_wav)
+        ori_wav = librosa.resample(ori_wav,orig_sr = sr,target_sr = SAMPLE_RATE)
+
+        mel_len,hop_size = 848,256
+        input_len = mel_len * hop_size
+        if len(ori_wav) < input_len:
+            input_wav = np.pad(ori_wav,(0,mel_len*hop_size),constant_values=0)
+        else:
+            input_wav = ori_wav[:input_len]
+
+        mel = TRANSFORMS_16000(input_wav)
+        return mel
+    def show_mel_fn(input_audio):
+        crop_len = 500 # the full mel cannot be shown due to gradio's Image bug when using tool='sketch'
+        crop_mel = self.gen_mel(input_audio)[:,:crop_len]
+        color_mel = cmap_transform(crop_mel)
+        return Image.fromarray((color_mel*255).astype(np.uint8))
+    def inpaint(batch, seed, ddim_steps, num_samples=1, W=512, H=512):
+        model = self.sampler.model
+
+        prng = np.random.RandomState(seed)
+        start_code = prng.randn(num_samples, model.first_stage_model.embed_dim, H // 8, W // 8)
+        start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
+
+        c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
+        cc = torch.nn.functional.interpolate(batch["mask"],
+                                             size=c.shape[-2:])
+        c = torch.cat((c, cc), dim=1) # (b,c+1,h,w): the extra 1 channel is the mask
+
+        shape = (c.shape[1]-1,)+c.shape[2:]
+        samples_ddim, _ = self.sampler.sample(S=ddim_steps,
+                                              conditioning=c,
+                                              batch_size=c.shape[0],
+                                              shape=shape,
+                                              verbose=False)
+        x_samples_ddim = model.decode_first_stage(samples_ddim)
+
+
+        mask = batch["mask"] # [-1,1]
+        mel = torch.clamp((batch["mel"]+1.0)/2.0,min=0.0, max=1.0)
+        mask = torch.clamp((batch["mask"]+1.0)/2.0,min=0.0, max=1.0)
+        predicted_mel = torch.clamp((x_samples_ddim+1.0)/2.0,min=0.0, max=1.0)
+        inpainted = (1-mask)*mel+mask*predicted_mel
+        inpainted = inpainted.cpu().numpy().squeeze()
+        inapint_wav = self.vocoder.vocode(inpainted)
+
+        return inpainted, inapint_wav
+    def predict(input_audio,mel_and_mask,ddim_steps,seed):
+        show_mel = np.array(mel_and_mask['image'].convert("L"))/255 # only part of the mel is displayed, so the mel has to be regenerated from the audio
+        mask = np.array(mel_and_mask["mask"].convert("L"))/255
+
+        mel_bins,mel_len = 80,848
+
+        input_mel = self.gen_mel(input_audio)[:,:mel_len] # only part of the mel is displayed, so the mel has to be regenerated from the audio
+        mask = np.pad(mask,((0,0),(0,mel_len-mask.shape[1])),mode='constant',constant_values=0) # pad the mask back to the size of the original mel
+        print(mask.shape,input_mel.shape)
+        with torch.no_grad():
+            batch = make_batch_sd(input_mel,mask,device,num_samples=1)
+            inpainted,gen_wav = self.inpaint(
+                batch=batch,
+                seed=seed,
+                ddim_steps=ddim_steps,
+                num_samples=1,
+                H=mel_bins, W=mel_len
+            )
+        inpainted = inpainted[:,:show_mel.shape[1]]
+        color_mel = cmap_transform(inpainted)
+        input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
+        gen_wav = (gen_wav * 32768).astype(np.int16)[:input_len]
+        return Image.fromarray((color_mel*255).astype(np.uint8)),(SAMPLE_RATE,gen_wav)
+
 class ASR:
     def __init__(self, device):
         print("Initializing Whisper to %s" % device)
@@ -244,17 +419,25 @@ class ConversationBot:
     def __init__(self):
        print("Initializing AudioChatGPT")
        self.llm = OpenAI(temperature=0)
-        self.t2i = T2I(device="cuda:2")
-        self.t2a = T2A(device="cuda:2")
-        self.asr = ASR(device="cuda:2")
+        self.t2i = T2I(device="cuda:1")
+        self.i2t = ImageCaptioning(device="cuda:1")
+        self.t2a = T2A(device="cuda:0")
+        self.i2a = I2A(device="cuda:1")
+        self.asr = ASR(device="cuda:1")
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
        self.tools = [
            Tool(name="Generate Image From User Input Text", func=self.t2i.inference,
                 description="useful for when you want to generate an image from a user input text and it saved it to a file. like: generate an image of an object or something, or generate an image that includes some objects. "
                             "The input to this tool should be a string, representing the text used to generate image. "),
+            Tool(name="Get Photo Description", func=self.i2t.inference,
+                 description="useful for when you want to know what is inside the photo. receives image_path as input. "
+                             "The input to this tool should be a string, representing the image_path. "),
            Tool(name="Generate Audio From User Input Text", func=self.t2a.inference,
                 description="useful for when you want to generate an audio from a user input text and it saved it to a file."
                             "The input to this tool should be a string, representing the text used to generate audio."),
+            Tool(name="Generate Audio From The Image", func=self.i2a.inference,
+                 description="useful for when you want to generate an audio based on an image."
+                             "The input to this tool should be a string, representing the image_path. "),
            Tool(name="Get Audio Transcription", func=self.asr.inference,
                 description="useful for when you want to know the text content corresponding to this audio, receives audio_path as input."
                             "The input to this tool should be a string, representing the audio_path.")
@@ -281,6 +464,7 @@ class ConversationBot:
         return state, state, temp_audio_filename
 
     def run_audio(self, audio, state, txt):
+        #print(audio.type)
        print("===============Running run_audio =============")
        print("Inputs:", audio, state)
        print("======>Previous memory:\n %s" % self.agent.memory)
@@ -299,30 +483,49 @@ class ConversationBot:
        print("Outputs:", state)
        return state, state, txt + ' ' + audio_filename + ' ', temp_audio_filename
 
-    # def run_image(self, image, state, txt):
-    #     print("===============Running run_image =============")
-    #     print("Inputs:", image, state)
-    #     print("======>Previous memory:\n %s" % self.agent.memory)
-    #     image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
-    #     print("======>Auto Resize Image...")
-    #     img = Image.open(image.name)
-    #     width, height = img.size
-    #     ratio = min(512 / width, 512 / height)
-    #     width_new, height_new = (round(width * ratio), round(height * ratio))
-    #     img = img.resize((width_new, height_new))
-    #     img = img.convert('RGB')
-    #     img.save(image_filename, "PNG")
-    #     print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
-    #     description = self.i2t.inference(image_filename)
-    #     Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. This information helps you to understand this image, but you should use tools to finish following tasks, " \
-    #                    "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(image_filename, description)
-    #     AI_prompt = "Received. "
-    #     self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
-    #     print("======>Current memory:\n %s" % self.agent.memory)
-    #     state = state + [(f"*{image_filename}*", AI_prompt)]
-    #     print("Outputs:", state)
-    #     return state, state, txt + ' ' + image_filename + ' '
+    def run_image_or_audio(self, file, state, txt):
+        file_type = file.name[-3:]
+        if file_type == "wav":
+            print("===============Running run_audio =============")
+            print("Inputs:", file, state)
+            print("======>Previous memory:\n %s" % self.agent.memory)
+            audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+            print("======>Auto Resize Audio...")
+            audio_load = whisper.load_audio(file.name)
+            soundfile.write(audio_filename, audio_load, samplerate = 16000)
+            global temp_audio_filename
+            temp_audio_filename = audio_filename
+            description = self.asr.inference(audio_filename)
+            Human_prompt = "\nHuman: provide an audio named {}. The description is: {}. This information helps you to understand this audio, but you should use tools to finish following tasks, " \
+                           "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(audio_filename, description)
+            AI_prompt = "Received. "
+            self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
+            state = state + [(f"*{audio_filename}*", AI_prompt)]
+            print("Outputs:", state)
+            return state, state, txt + ' ' + audio_filename + ' ', temp_audio_filename
+        else:
+            print("===============Running run_image =============")
+            print("Inputs:", file, state)
+            print("======>Previous memory:\n %s" % self.agent.memory)
+            image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
+            print("======>Auto Resize Image...")
+            img = Image.open(file.name)
+            width, height = img.size
+            ratio = min(512 / width, 512 / height)
+            width_new, height_new = (round(width * ratio), round(height * ratio))
+            img = img.resize((width_new, height_new))
+            img = img.convert('RGB')
+            img.save(image_filename, "PNG")
+            print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
+            description = self.i2t.inference(image_filename)
+            Human_prompt = "\nHuman: provide a figure named {}. The description is: {}. This information helps you to understand this image, but you should use tools to finish following tasks, " \
+                           "rather than directly imagine from my description. If you understand, say \"Received\". \n".format(image_filename, description)
+            AI_prompt = "Received. "
+            self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
+            print("======>Current memory:\n %s" % self.agent.memory)
+            state = state + [(f"*{image_filename}*", AI_prompt)]
+            print("Outputs:", state)
+            return state, state, txt + ' ' + image_filename + ' ', temp_audio_filename
 
 if __name__ == '__main__':
     bot = ConversationBot()
@@ -337,15 +540,15 @@ if __name__ == '__main__':
            with gr.Column(scale=0.15, min_width=0):
                clear = gr.Button("Clear️")
            with gr.Column(scale=0.15, min_width=0):
-                btn = gr.UploadButton("Upload", file_types=["audio"])
+                btn = gr.UploadButton("Upload", file_types=["image","audio"])
        with gr.Column():
            outaudio = gr.Audio()
        txt.submit(bot.run_text, [txt, state], [chatbot, state, outaudio])
        txt.submit(lambda: "", None, txt)
        #btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
-        btn.upload(bot.run_audio, [btn, state, txt], [chatbot, state, txt, outaudio])
+        btn.upload(bot.run_image_or_audio, [btn, state, txt], [chatbot, state, txt, outaudio])
        clear.click(bot.memory.clear)
        clear.click(lambda: [], None, chatbot)
        clear.click(lambda: [], None, state)
-        clear.click(lambda: [], None, outaudio)
-    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+        #clear.click(lambda: [], None, outaudio)
+    demo.launch(server_name="0.0.0.0", server_port=7862, share=True)
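
Note: the new run_image_or_audio handler decides between the audio path and the image path purely by the uploaded file's extension (file.name[-3:] == "wav"). A minimal, self-contained sketch of that dispatch (illustrative only, not part of this commit; the function name route_upload is hypothetical, and the real handler additionally resamples audio via whisper/soundfile, resizes and captions images, and appends to the agent memory):

from pathlib import Path

def route_upload(filename: str) -> str:
    # Mirrors the commit's check: .wav uploads go down the audio branch,
    # everything else is treated as an image.
    if Path(filename).suffix.lower() == ".wav":
        return "audio"   # -> transcribed with ASR, audio_filename injected into the prompt
    return "image"       # -> resized, captioned with BLIP, image_filename injected into the prompt

print(route_upload("sample.wav"))   # audio
print(route_upload("photo.png"))    # image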