mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
refac
This commit is contained in:
@@ -71,7 +71,10 @@ from open_webui.models.models import Models
|
||||
from open_webui.retrieval.utils import get_sources_from_items
|
||||
|
||||
|
||||
from open_webui.utils.chat import generate_chat_completion
|
||||
from open_webui.utils.chat import (
|
||||
generate_chat_completion,
|
||||
chat_completed,
|
||||
)
|
||||
from open_webui.utils.task import (
|
||||
get_task_model_id,
|
||||
rag_template,
|
||||
@@ -1079,11 +1082,17 @@ def apply_params_to_form_data(form_data, model):
|
||||
return form_data
|
||||
|
||||
|
||||
async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
async def process_chat_payload(request, form_data, user, metadata, extra_params):
|
||||
# Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
|
||||
# -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
|
||||
# -> Chat Files
|
||||
|
||||
event_emitter = extra_params.get("__event_emitter__", None)
|
||||
event_caller = extra_params.get("__event_call__", None)
|
||||
|
||||
oauth_token = extra_params.get("__oauth_token__", None)
|
||||
model = extra_params.get("__model__", None)
|
||||
|
||||
form_data = apply_params_to_form_data(form_data, model)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
|
||||
@@ -1096,29 +1105,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
except:
|
||||
pass
|
||||
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
"__oauth_token__": oauth_token,
|
||||
}
|
||||
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
# Initialize contexts and citation
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
@@ -1529,7 +1515,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
|
||||
|
||||
async def process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
request, response, form_data, user, metadata, extra_params, events, tasks
|
||||
):
|
||||
async def background_tasks_handler():
|
||||
message = None
|
||||
@@ -1752,18 +1738,9 @@ async def process_chat_response(
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
event_emitter = None
|
||||
event_caller = None
|
||||
if (
|
||||
"session_id" in metadata
|
||||
and metadata["session_id"]
|
||||
and "chat_id" in metadata
|
||||
and metadata["chat_id"]
|
||||
and "message_id" in metadata
|
||||
and metadata["message_id"]
|
||||
):
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_caller = get_event_call(metadata)
|
||||
model = extra_params.get("__model__", None)
|
||||
event_emitter = extra_params.get("__event_emitter__", None)
|
||||
event_caller = extra_params.get("__event_call__", None)
|
||||
|
||||
# Non-streaming response
|
||||
if not isinstance(response, StreamingResponse):
|
||||
@@ -1832,8 +1809,18 @@ async def process_chat_response(
|
||||
}
|
||||
)
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
@@ -1845,16 +1832,6 @@ async def process_chat_response(
|
||||
}
|
||||
)
|
||||
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
@@ -1923,32 +1900,12 @@ async def process_chat_response(
|
||||
):
|
||||
return response
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_caller,
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__oauth_token__": oauth_token,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
# Streaming response
|
||||
if event_emitter and event_caller:
|
||||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
@@ -3163,12 +3120,35 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
completed_res = await chat_completed(
|
||||
request,
|
||||
{
|
||||
"id": metadata.get("message_id"),
|
||||
"chat_id": metadata.get("chat_id"),
|
||||
"session_id": metadata.get("session_id"),
|
||||
"filter_ids": metadata.get("filter_ids", []),
|
||||
|
||||
"model": form_data.get("model"),
|
||||
"messages": [*form_data.get("messages", []), response_message],
|
||||
},
|
||||
user,
|
||||
metadata,
|
||||
extra_params,
|
||||
)
|
||||
|
||||
if completed_res and completed_res.get("messages"):
|
||||
for message in completed_res["messages"]:
|
||||
|
||||
|
||||
|
||||
if response.background is not None:
|
||||
await response.background()
|
||||
|
||||
return await response_handler(response, events)
|
||||
|
||||
else:
|
||||
response_message = {}
|
||||
# Fallback to the original response
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
@@ -3198,6 +3178,22 @@ async def process_chat_response(
|
||||
if data:
|
||||
yield data
|
||||
|
||||
await chat_completed(
|
||||
request,
|
||||
{
|
||||
"id": metadata.get("message_id"),
|
||||
"chat_id": metadata.get("chat_id"),
|
||||
"session_id": metadata.get("session_id"),
|
||||
"filter_ids": metadata.get("filter_ids", []),
|
||||
|
||||
"model": form_data.get("model"),
|
||||
"messages": [*form_data.get("messages", []), response_message],
|
||||
},
|
||||
user,
|
||||
metadata,
|
||||
extra_params,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator, events),
|
||||
headers=dict(response.headers),
|
||||
|
||||
Reference in New Issue
Block a user