This commit is contained in:
Timothy J. Baek
2024-07-11 13:43:44 -07:00
parent 9ab97b834a
commit f462744fc8
4 changed files with 33 additions and 12 deletions

View File

@@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content={"detail": str(e)},
)
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
task = None
if "task" in body:
task = body["task"]
del body["task"]
# Extract session_id, chat_id and message_id from the request body
session_id = None
if "session_id" in body:
@@ -632,6 +638,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
message_id = body["id"]
del body["id"]
__event_emitter__ = await get_event_emitter(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
)
@@ -691,6 +699,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(citations) > 0:
data_items.append({"citations": citations})
body["metadata"] = {
"session_id": session_id,
"chat_id": chat_id,
"message_id": message_id,
"task": task,
}
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
@@ -811,9 +826,6 @@ def filter_pipeline(payload, user):
if "detail" in res:
raise Exception(r.status_code, res["detail"])
if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
del payload["task"]
return payload
@@ -1024,11 +1036,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
pipe = model.get("pipe")
if pipe:
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)