mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 20:07:49 +01:00
Merge remote-tracking branch 'upstream/dev' into playwright
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import base64
|
||||
|
||||
import asyncio
|
||||
from aiocache import cached
|
||||
@@ -10,6 +12,7 @@ import json
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import ast
|
||||
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@@ -55,6 +58,7 @@ from open_webui.utils.task import (
|
||||
tools_function_calling_generation_template,
|
||||
)
|
||||
from open_webui.utils.misc import (
|
||||
deep_update,
|
||||
get_message_list,
|
||||
add_or_update_system_message,
|
||||
add_or_update_user_message,
|
||||
@@ -69,6 +73,7 @@ from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.tasks import create_task
|
||||
|
||||
from open_webui.config import (
|
||||
CACHE_DIR,
|
||||
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
DEFAULT_CODE_INTERPRETER_PROMPT,
|
||||
)
|
||||
@@ -180,7 +185,7 @@ async def chat_completion_filter_functions_handler(request, body, model, extra_p
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(
|
||||
request: Request, body: dict, user: UserModel, models, extra_params: dict
|
||||
request: Request, body: dict, user: UserModel, models, tools
|
||||
) -> tuple[dict, dict]:
|
||||
async def get_content_from_response(response) -> Optional[str]:
|
||||
content = None
|
||||
@@ -215,35 +220,15 @@ async def chat_completion_tools_handler(
|
||||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||
}
|
||||
|
||||
# If tool_ids field is present, call the functions
|
||||
metadata = body.get("metadata", {})
|
||||
|
||||
tool_ids = metadata.get("tool_ids", None)
|
||||
log.debug(f"{tool_ids=}")
|
||||
if not tool_ids:
|
||||
return body, {}
|
||||
|
||||
skip_files = False
|
||||
sources = []
|
||||
|
||||
task_model_id = get_task_model_id(
|
||||
body["model"],
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
tools = get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": models[task_model_id],
|
||||
"__messages__": body["messages"],
|
||||
"__files__": metadata.get("files", []),
|
||||
},
|
||||
)
|
||||
log.info(f"{tools=}")
|
||||
|
||||
skip_files = False
|
||||
sources = []
|
||||
|
||||
specs = [tool["spec"] for tool in tools.values()]
|
||||
tools_specs = json.dumps(specs)
|
||||
@@ -278,6 +263,8 @@ async def chat_completion_tools_handler(
|
||||
result = json.loads(content)
|
||||
|
||||
async def tool_call_handler(tool_call):
|
||||
nonlocal skip_files
|
||||
|
||||
log.debug(f"{tool_call=}")
|
||||
|
||||
tool_function_name = tool_call.get("name", None)
|
||||
@@ -418,7 +405,7 @@ async def chat_web_search_handler(
|
||||
},
|
||||
}
|
||||
)
|
||||
return
|
||||
return form_data
|
||||
|
||||
searchQuery = queries[0]
|
||||
|
||||
@@ -641,7 +628,9 @@ async def chat_completion_files_handler(
|
||||
lambda: get_sources_from_files(
|
||||
files=files,
|
||||
queries=queries,
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, user=user
|
||||
),
|
||||
k=request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
@@ -693,6 +682,7 @@ def apply_params_to_form_data(form_data, model):
|
||||
|
||||
|
||||
async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
form_data = apply_params_to_form_data(form_data, model)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
|
||||
@@ -715,6 +705,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
# Initialize contexts and citation
|
||||
models = request.app.state.MODELS
|
||||
task_model_id = get_task_model_id(
|
||||
form_data["model"],
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
events = []
|
||||
sources = []
|
||||
@@ -799,13 +795,41 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
}
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_tools_handler(
|
||||
request, form_data, user, models, extra_params
|
||||
tool_ids = metadata.get("tool_ids", None)
|
||||
log.debug(f"{tool_ids=}")
|
||||
|
||||
if tool_ids:
|
||||
# If tool_ids field is present, then get the tools
|
||||
tools = get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": models[task_model_id],
|
||||
"__messages__": form_data["messages"],
|
||||
"__files__": metadata.get("files", []),
|
||||
},
|
||||
)
|
||||
sources.extend(flags.get("sources", []))
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.info(f"{tools=}")
|
||||
|
||||
if metadata.get("function_calling") == "native":
|
||||
# If the function calling is native, then call the tools function calling handler
|
||||
metadata["tools"] = tools
|
||||
form_data["tools"] = [
|
||||
{"type": "function", "function": tool.get("spec", {})}
|
||||
for tool in tools.values()
|
||||
]
|
||||
else:
|
||||
# If the function calling is not native, then call the tools function calling handler
|
||||
try:
|
||||
form_data, flags = await chat_completion_tools_handler(
|
||||
request, form_data, user, models, tools
|
||||
)
|
||||
sources.extend(flags.get("sources", []))
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_files_handler(request, form_data, user)
|
||||
@@ -821,11 +845,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
if "document" in source:
|
||||
for doc_idx, doc_context in enumerate(source["document"]):
|
||||
metadata = source.get("metadata")
|
||||
doc_metadata = source.get("metadata")
|
||||
doc_source_id = None
|
||||
|
||||
if metadata:
|
||||
doc_source_id = metadata[doc_idx].get("source", source_id)
|
||||
if doc_metadata:
|
||||
doc_source_id = doc_metadata[doc_idx].get("source", source_id)
|
||||
|
||||
if source_id:
|
||||
context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
|
||||
@@ -882,7 +906,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
}
|
||||
)
|
||||
|
||||
return form_data, events
|
||||
return form_data, metadata, events
|
||||
|
||||
|
||||
async def process_chat_response(
|
||||
@@ -1100,6 +1124,40 @@ async def process_chat_response(
|
||||
for block in content_blocks:
|
||||
if block["type"] == "text":
|
||||
content = f"{content}{block['content'].strip()}\n"
|
||||
elif block["type"] == "tool_calls":
|
||||
attributes = block.get("attributes", {})
|
||||
|
||||
block_content = block.get("content", [])
|
||||
results = block.get("results", [])
|
||||
|
||||
if results:
|
||||
|
||||
result_display_content = ""
|
||||
|
||||
for result in results:
|
||||
tool_call_id = result.get("tool_call_id", "")
|
||||
tool_name = ""
|
||||
|
||||
for tool_call in block_content:
|
||||
if tool_call.get("id", "") == tool_call_id:
|
||||
tool_name = tool_call.get("function", {}).get(
|
||||
"name", ""
|
||||
)
|
||||
break
|
||||
|
||||
result_display_content = f"{result_display_content}\n> {tool_name}: {result.get('content', '')}"
|
||||
|
||||
if not raw:
|
||||
content = f'{content}\n<details type="tool_calls" done="true" content="{html.escape(json.dumps(block_content))}" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n{result_display_content}\n</details>\n'
|
||||
else:
|
||||
tool_calls_display_content = ""
|
||||
|
||||
for tool_call in block_content:
|
||||
tool_calls_display_content = f"{tool_calls_display_content}\n> Executing {tool_call.get('function', {}).get('name', '')}"
|
||||
|
||||
if not raw:
|
||||
content = f'{content}\n<details type="tool_calls" done="false" content="{html.escape(json.dumps(block_content))}">\n<summary>Tool Executing...</summary>\n{tool_calls_display_content}\n</details>\n'
|
||||
|
||||
elif block["type"] == "reasoning":
|
||||
reasoning_display_content = "\n".join(
|
||||
(f"> {line}" if not line.startswith(">") else line)
|
||||
@@ -1108,16 +1166,16 @@ async def process_chat_response(
|
||||
|
||||
reasoning_duration = block.get("duration", None)
|
||||
|
||||
if reasoning_duration:
|
||||
if reasoning_duration is not None:
|
||||
if raw:
|
||||
content = f'{content}<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
else:
|
||||
content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
else:
|
||||
if raw:
|
||||
content = f'{content}<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
else:
|
||||
content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
|
||||
elif block["type"] == "code_interpreter":
|
||||
attributes = block.get("attributes", {})
|
||||
@@ -1128,20 +1186,20 @@ async def process_chat_response(
|
||||
output = html.escape(json.dumps(output))
|
||||
|
||||
if raw:
|
||||
content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
|
||||
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
|
||||
else:
|
||||
content = f'{content}<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
else:
|
||||
if raw:
|
||||
content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
|
||||
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
|
||||
else:
|
||||
content = f'{content}<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
|
||||
else:
|
||||
block_content = str(block["content"]).strip()
|
||||
content = f"{content}{block['type']}: {block_content}\n"
|
||||
|
||||
return content
|
||||
return content.strip()
|
||||
|
||||
def tag_content_handler(content_type, tags, content, content_blocks):
|
||||
end_flag = False
|
||||
@@ -1149,6 +1207,8 @@ async def process_chat_response(
|
||||
def extract_attributes(tag_content):
|
||||
"""Extract attributes from a tag if they exist."""
|
||||
attributes = {}
|
||||
if not tag_content: # Ensure tag_content is not None
|
||||
return attributes
|
||||
# Match attributes in the format: key="value" (ignores single quotes for simplicity)
|
||||
matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content)
|
||||
for key, value in matches:
|
||||
@@ -1158,17 +1218,35 @@ async def process_chat_response(
|
||||
if content_blocks[-1]["type"] == "text":
|
||||
for tag in tags:
|
||||
# Match start tag e.g., <tag> or <tag attr="value">
|
||||
start_tag_pattern = rf"<{tag}(.*?)>"
|
||||
start_tag_pattern = rf"<{tag}(\s.*?)?>"
|
||||
match = re.search(start_tag_pattern, content)
|
||||
if match:
|
||||
# Extract attributes in the tag (if present)
|
||||
attributes = extract_attributes(match.group(1))
|
||||
attr_content = (
|
||||
match.group(1) if match.group(1) else ""
|
||||
) # Ensure it's not None
|
||||
attributes = extract_attributes(
|
||||
attr_content
|
||||
) # Extract attributes safely
|
||||
|
||||
# Capture everything before and after the matched tag
|
||||
before_tag = content[
|
||||
: match.start()
|
||||
] # Content before opening tag
|
||||
after_tag = content[
|
||||
match.end() :
|
||||
] # Content after opening tag
|
||||
|
||||
# Remove the start tag from the currently handling text block
|
||||
content_blocks[-1]["content"] = content_blocks[-1][
|
||||
"content"
|
||||
].replace(match.group(0), "")
|
||||
|
||||
if before_tag:
|
||||
content_blocks[-1]["content"] = before_tag
|
||||
|
||||
if not content_blocks[-1]["content"]:
|
||||
content_blocks.pop()
|
||||
|
||||
# Append the new block
|
||||
content_blocks.append(
|
||||
{
|
||||
@@ -1179,52 +1257,100 @@ async def process_chat_response(
|
||||
"started_at": time.time(),
|
||||
}
|
||||
)
|
||||
|
||||
if after_tag:
|
||||
content_blocks[-1]["content"] = after_tag
|
||||
|
||||
break
|
||||
elif content_blocks[-1]["type"] == content_type:
|
||||
tag = content_blocks[-1]["tag"]
|
||||
# Match end tag e.g., </tag>
|
||||
end_tag_pattern = rf"</{tag}>"
|
||||
|
||||
# Check if the content has the end tag
|
||||
if re.search(end_tag_pattern, content):
|
||||
end_flag = True
|
||||
|
||||
block_content = content_blocks[-1]["content"]
|
||||
# Strip start and end tags from the content
|
||||
start_tag_pattern = rf"<{tag}(.*?)>"
|
||||
block_content = re.sub(
|
||||
start_tag_pattern, "", block_content
|
||||
).strip()
|
||||
block_content = re.sub(
|
||||
end_tag_pattern, "", block_content
|
||||
).strip()
|
||||
|
||||
end_tag_regex = re.compile(end_tag_pattern, re.DOTALL)
|
||||
split_content = end_tag_regex.split(block_content, maxsplit=1)
|
||||
|
||||
# Content inside the tag
|
||||
block_content = (
|
||||
split_content[0].strip() if split_content else ""
|
||||
)
|
||||
|
||||
# Leftover content (everything after `</tag>`)
|
||||
leftover_content = (
|
||||
split_content[1].strip() if len(split_content) > 1 else ""
|
||||
)
|
||||
|
||||
if block_content:
|
||||
end_flag = True
|
||||
content_blocks[-1]["content"] = block_content
|
||||
content_blocks[-1]["ended_at"] = time.time()
|
||||
content_blocks[-1]["duration"] = int(
|
||||
content_blocks[-1]["ended_at"]
|
||||
- content_blocks[-1]["started_at"]
|
||||
)
|
||||
|
||||
# Reset the content_blocks by appending a new text block
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
# Clean processed content
|
||||
content = re.sub(
|
||||
rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
|
||||
"",
|
||||
content,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
if content_type != "code_interpreter":
|
||||
if leftover_content:
|
||||
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": leftover_content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
# Remove the block if content is empty
|
||||
content_blocks.pop()
|
||||
|
||||
if leftover_content:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": leftover_content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
|
||||
# Clean processed content
|
||||
content = re.sub(
|
||||
rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
|
||||
"",
|
||||
content,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
|
||||
return content, content_blocks, end_flag
|
||||
|
||||
message = Chats.get_message_by_id_and_message_id(
|
||||
metadata["chat_id"], metadata["message_id"]
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
content = message.get("content", "") if message else ""
|
||||
content_blocks = [
|
||||
{
|
||||
@@ -1235,9 +1361,18 @@ async def process_chat_response(
|
||||
|
||||
# We might want to disable this by default
|
||||
DETECT_REASONING = True
|
||||
DETECT_CODE_INTERPRETER = True
|
||||
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
|
||||
"code_interpreter", False
|
||||
)
|
||||
|
||||
reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
|
||||
reasoning_tags = [
|
||||
"think",
|
||||
"thinking",
|
||||
"reason",
|
||||
"reasoning",
|
||||
"thought",
|
||||
"Thought",
|
||||
]
|
||||
code_interpreter_tags = ["code_interpreter"]
|
||||
|
||||
try:
|
||||
@@ -1262,6 +1397,8 @@ async def process_chat_response(
|
||||
nonlocal content
|
||||
nonlocal content_blocks
|
||||
|
||||
response_tool_calls = []
|
||||
|
||||
async for line in response.body_iterator:
|
||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
||||
data = line
|
||||
@@ -1294,10 +1431,54 @@ async def process_chat_response(
|
||||
if not choices:
|
||||
continue
|
||||
|
||||
value = choices[0].get("delta", {}).get("content")
|
||||
delta = choices[0].get("delta", {})
|
||||
delta_tool_calls = delta.get("tool_calls", None)
|
||||
|
||||
if delta_tool_calls:
|
||||
for delta_tool_call in delta_tool_calls:
|
||||
tool_call_index = delta_tool_call.get("index")
|
||||
|
||||
if tool_call_index is not None:
|
||||
if (
|
||||
len(response_tool_calls)
|
||||
<= tool_call_index
|
||||
):
|
||||
response_tool_calls.append(
|
||||
delta_tool_call
|
||||
)
|
||||
else:
|
||||
delta_name = delta_tool_call.get(
|
||||
"function", {}
|
||||
).get("name")
|
||||
delta_arguments = delta_tool_call.get(
|
||||
"function", {}
|
||||
).get("arguments")
|
||||
|
||||
if delta_name:
|
||||
response_tool_calls[
|
||||
tool_call_index
|
||||
]["function"]["name"] += delta_name
|
||||
|
||||
if delta_arguments:
|
||||
response_tool_calls[
|
||||
tool_call_index
|
||||
]["function"][
|
||||
"arguments"
|
||||
] += delta_arguments
|
||||
|
||||
value = delta.get("content")
|
||||
|
||||
if value:
|
||||
content = f"{content}{value}"
|
||||
|
||||
if not content_blocks:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
|
||||
content_blocks[-1]["content"] = (
|
||||
content_blocks[-1]["content"] + value
|
||||
)
|
||||
@@ -1357,14 +1538,46 @@ async def process_chat_response(
|
||||
log.debug("Error: ", e)
|
||||
continue
|
||||
|
||||
# Clean up the last text block
|
||||
if content_blocks[-1]["type"] == "text":
|
||||
content_blocks[-1]["content"] = content_blocks[-1][
|
||||
"content"
|
||||
].strip()
|
||||
if content_blocks:
|
||||
# Clean up the last text block
|
||||
if content_blocks[-1]["type"] == "text":
|
||||
content_blocks[-1]["content"] = content_blocks[-1][
|
||||
"content"
|
||||
].strip()
|
||||
|
||||
if not content_blocks[-1]["content"]:
|
||||
content_blocks.pop()
|
||||
if not content_blocks[-1]["content"]:
|
||||
content_blocks.pop()
|
||||
|
||||
if not content_blocks:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
|
||||
if response_tool_calls:
|
||||
tool_calls.append(response_tool_calls)
|
||||
|
||||
if response.background:
|
||||
await response.background()
|
||||
|
||||
await stream_body_handler(response)
|
||||
|
||||
MAX_TOOL_CALL_RETRIES = 5
|
||||
tool_call_retries = 0
|
||||
|
||||
while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
|
||||
tool_call_retries += 1
|
||||
|
||||
response_tool_calls = tool_calls.pop(0)
|
||||
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_calls",
|
||||
"content": response_tool_calls,
|
||||
}
|
||||
)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
@@ -1375,37 +1588,54 @@ async def process_chat_response(
|
||||
}
|
||||
)
|
||||
|
||||
if response.background:
|
||||
await response.background()
|
||||
tools = metadata.get("tools", {})
|
||||
|
||||
await stream_body_handler(response)
|
||||
results = []
|
||||
for tool_call in response_tool_calls:
|
||||
print("\n\n" + str(tool_call) + "\n\n")
|
||||
tool_call_id = tool_call.get("id", "")
|
||||
tool_name = tool_call.get("function", {}).get("name", "")
|
||||
|
||||
MAX_RETRIES = 5
|
||||
retries = 0
|
||||
|
||||
while (
|
||||
content_blocks[-1]["type"] == "code_interpreter"
|
||||
and retries < MAX_RETRIES
|
||||
):
|
||||
retries += 1
|
||||
log.debug(f"Attempt count: {retries}")
|
||||
|
||||
output = ""
|
||||
try:
|
||||
if content_blocks[-1]["attributes"].get("type") == "code":
|
||||
output = await event_caller(
|
||||
{
|
||||
"type": "execute:python",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"code": content_blocks[-1]["content"],
|
||||
},
|
||||
}
|
||||
tool_function_params = {}
|
||||
try:
|
||||
# json.loads cannot be used because some models do not produce valid JSON
|
||||
tool_function_params = ast.literal_eval(
|
||||
tool_call.get("function", {}).get("arguments", "{}")
|
||||
)
|
||||
except Exception as e:
|
||||
output = str(e)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
|
||||
tool_result = None
|
||||
|
||||
if tool_name in tools:
|
||||
tool = tools[tool_name]
|
||||
spec = tool.get("spec", {})
|
||||
|
||||
try:
|
||||
required_params = spec.get("parameters", {}).get(
|
||||
"required", []
|
||||
)
|
||||
tool_function = tool["callable"]
|
||||
tool_function_params = {
|
||||
k: v
|
||||
for k, v in tool_function_params.items()
|
||||
if k in required_params
|
||||
}
|
||||
tool_result = await tool_function(
|
||||
**tool_function_params
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = str(e)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
content_blocks[-1]["results"] = results
|
||||
|
||||
content_blocks[-1]["output"] = output
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
@@ -1435,7 +1665,16 @@ async def process_chat_response(
|
||||
"content": serialize_content_blocks(
|
||||
content_blocks, raw=True
|
||||
),
|
||||
"tool_calls": response_tool_calls,
|
||||
},
|
||||
*[
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": result["tool_call_id"],
|
||||
"content": result["content"],
|
||||
}
|
||||
for result in results
|
||||
],
|
||||
],
|
||||
},
|
||||
user,
|
||||
@@ -1449,6 +1688,110 @@ async def process_chat_response(
|
||||
log.debug(e)
|
||||
break
|
||||
|
||||
if DETECT_CODE_INTERPRETER:
|
||||
MAX_RETRIES = 5
|
||||
retries = 0
|
||||
|
||||
while (
|
||||
content_blocks[-1]["type"] == "code_interpreter"
|
||||
and retries < MAX_RETRIES
|
||||
):
|
||||
retries += 1
|
||||
log.debug(f"Attempt count: {retries}")
|
||||
|
||||
output = ""
|
||||
try:
|
||||
if content_blocks[-1]["attributes"].get("type") == "code":
|
||||
output = await event_caller(
|
||||
{
|
||||
"type": "execute:python",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"code": content_blocks[-1]["content"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(output, dict):
|
||||
stdout = output.get("stdout", "")
|
||||
|
||||
if stdout:
|
||||
stdoutLines = stdout.split("\n")
|
||||
for idx, line in enumerate(stdoutLines):
|
||||
if "data:image/png;base64" in line:
|
||||
id = str(uuid4())
|
||||
|
||||
# ensure the path exists
|
||||
os.makedirs(
|
||||
os.path.join(CACHE_DIR, "images"),
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
image_path = os.path.join(
|
||||
CACHE_DIR,
|
||||
f"images/{id}.png",
|
||||
)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(
|
||||
base64.b64decode(
|
||||
line.split(",")[1]
|
||||
)
|
||||
)
|
||||
|
||||
stdoutLines[idx] = (
|
||||
f""
|
||||
)
|
||||
|
||||
output["stdout"] = "\n".join(stdoutLines)
|
||||
except Exception as e:
|
||||
output = str(e)
|
||||
|
||||
content_blocks[-1]["output"] = output
|
||||
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": {
|
||||
"content": serialize_content_blocks(content_blocks),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
res = await generate_chat_completion(
|
||||
request,
|
||||
{
|
||||
"model": model_id,
|
||||
"stream": True,
|
||||
"messages": [
|
||||
*form_data["messages"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": serialize_content_blocks(
|
||||
content_blocks, raw=True
|
||||
),
|
||||
},
|
||||
],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
if isinstance(res, StreamingResponse):
|
||||
await stream_body_handler(res)
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
break
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
data = {
|
||||
"done": True,
|
||||
|
||||
Reference in New Issue
Block a user