This commit is contained in:
Timothy J. Baek
2024-06-18 16:08:42 -07:00
parent bcc27e3852
commit 514c7f1520
3 changed files with 30 additions and 26 deletions

View File

@@ -170,7 +170,9 @@ app.state.MODELS = {}
origins = ["*"]
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
@@ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
function_result = function(**params)
except Exception as e:
print(e)
@@ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try:
response = await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
@@ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
# If files field is present, generate RAG completions
if "files" in data:
data = {**data}
rag_context, citations = get_rag_context(
@@ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
@@ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
try:
context = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
template,
model_id,
user,
)
return context
except Exception as e: