feat: __event_emitter__

This commit is contained in:
Timothy J. Baek
2024-07-01 20:05:02 -07:00
parent e5895af7a0
commit a07051f51b
2 changed files with 54 additions and 9 deletions

View File

@@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app
from apps.socket.main import sio, app as socket_app
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
@@ -277,7 +277,14 @@ def get_filter_function_ids(model):
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user, model
messages,
files,
tool_id,
template,
task_model_id,
user,
model,
__event_emitter__=None,
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
@@ -414,6 +421,13 @@ async def get_function_call_response(
"__id__": tool_id,
}
if "__event_emitter__" in sig.parameters:
# Call the function with the '__event_emitter__' parameter included
params = {
**params,
"__event_emitter__": model,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
@@ -437,7 +451,7 @@ async def get_function_call_response(
return None, None, False
async def chat_completion_functions_handler(body, model, user):
async def chat_completion_functions_handler(body, model, user, __event_emitter__):
skip_files = None
filter_ids = get_filter_function_ids(model)
@@ -503,6 +517,11 @@ async def chat_completion_functions_handler(body, model, user):
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
@@ -520,7 +539,7 @@ async def chat_completion_functions_handler(body, model, user):
return body, {}
async def chat_completion_tools_handler(body, model, user):
async def chat_completion_tools_handler(body, model, user, __event_emitter__):
skip_files = None
contexts = []
@@ -542,6 +561,7 @@ async def chat_completion_tools_handler(body, model, user):
task_model_id=task_model_id,
user=user,
model=model,
__event_emitter__=__event_emitter__,
)
print(file_handler)
@@ -614,7 +634,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content={"detail": str(e)},
)
# Extract chat_id and message_id from the request body
# Extract session_id, chat_id and message_id from the request body
session_id = None
if "session_id" in body:
session_id = body["session_id"]
del body["session_id"]
chat_id = None
if "chat_id" in body:
chat_id = body["chat_id"]
@@ -624,6 +648,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
message_id = body["id"]
del body["id"]
async def __event_emitter__(data):
await sio.emit(
"chat-events",
{
"chat_id": chat_id,
"message_id": message_id,
"data": data,
},
to=session_id,
)
# Initialize data_items to store additional data to be sent to the client
data_items = []
@@ -631,10 +666,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
contexts = []
citations = []
print(body)
try:
body, flags = await chat_completion_functions_handler(body, model, user)
body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__
)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -642,7 +677,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
try:
body, flags = await chat_completion_tools_handler(body, model, user)
body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__
)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))