feat: tools full integration

This commit is contained in:
Timothy J. Baek
2024-06-11 00:18:45 -07:00
parent a27175d672
commit 3d6f5f418d
4 changed files with 75 additions and 39 deletions

View File

@@ -185,39 +185,48 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
model = app.state.MODELS[task_model_id]
response = None
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
print(response)
content = response["choices"][0]["message"]["content"]
content = None
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Parse the function response
if content != "":
result = json.loads(content)
print(result)
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
# Parse the function response
if content is not None:
result = json.loads(content)
print(result)
function = getattr(toolkit_module, result["name"])
function_result = None
try:
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
# Add the function result to the system prompt
if function_result:
return function_result
function = getattr(toolkit_module, result["name"])
function_result = None
try:
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
return None
@@ -285,15 +294,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
print(response)
if response:
context += f"\n{response}"
context = ("\n" if context != "" else "") + response
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
data["messages"] = add_or_update_system_message(
system_prompt, data["messages"]
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
del data["tool_ids"]