feat: __event_call__ support

This commit is contained in:
Timothy J. Baek
2024-07-08 21:39:06 -07:00
parent 1d979d9b75
commit 1b7ff1c5df
3 changed files with 74 additions and 11 deletions

View File

@@ -302,6 +302,7 @@ async def get_function_call_response(
user,
model,
__event_emitter__=None,
__event_call__=None,
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
@@ -445,6 +446,13 @@ async def get_function_call_response(
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
@@ -468,7 +476,9 @@ async def get_function_call_response(
return None, None, False
async def chat_completion_functions_handler(body, model, user, __event_emitter__):
async def chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None
filter_ids = get_filter_function_ids(model)
@@ -534,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
@@ -556,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
return body, {}
async def chat_completion_tools_handler(body, model, user, __event_emitter__):
async def chat_completion_tools_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None
contexts = []
@@ -579,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__):
user=user,
model=model,
__event_emitter__=__event_emitter__,
__event_call__=__event_call__,
)
print(file_handler)
@@ -676,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
to=session_id,
)
async def __event_call__(data):
response = await sio.call(
"chat-events",
{"chat_id": chat_id, "message_id": message_id, "data": data},
to=session_id,
)
return response
# Initialize data_items to store additional data to be sent to the client
data_items = []
@@ -685,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try:
body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__
body, model, user, __event_emitter__, __event_call__
)
except Exception as e:
return JSONResponse(
@@ -695,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try:
body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__
body, model, user, __event_emitter__, __event_call__
)
contexts.extend(flags.get("contexts", []))