mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
refac
This commit is contained in:
@@ -64,6 +64,7 @@ from open_webui.socket.main import (
|
||||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
get_event_emitter,
|
||||
get_event_call,
|
||||
get_models_in_use,
|
||||
get_active_user_ids,
|
||||
)
|
||||
@@ -481,7 +482,6 @@ from open_webui.utils.models import (
|
||||
)
|
||||
from open_webui.utils.chat import (
|
||||
generate_chat_completion as chat_completion_handler,
|
||||
chat_completed as chat_completed_handler,
|
||||
chat_action as chat_action_handler,
|
||||
)
|
||||
from open_webui.utils.embeddings import generate_embeddings
|
||||
@@ -1566,10 +1566,40 @@ async def chat_completion(
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
async def process_chat(request, form_data, user, metadata, model):
|
||||
try:
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
"__oauth_token__": oauth_token,
|
||||
}
|
||||
except Exception as e:
|
||||
log.debug(f"Error setting up extra params: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
async def process_chat(request, form_data, user, metadata, extra_params):
|
||||
try:
|
||||
form_data, metadata, events = await process_chat_payload(
|
||||
request, form_data, user, metadata, model
|
||||
request, form_data, user, metadata, extra_params
|
||||
)
|
||||
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
@@ -1587,7 +1617,14 @@ async def chat_completion(
|
||||
pass
|
||||
|
||||
return await process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
request,
|
||||
response,
|
||||
form_data,
|
||||
user,
|
||||
metadata,
|
||||
extra_params,
|
||||
events,
|
||||
tasks,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
log.info("Chat processing was cancelled")
|
||||
@@ -1646,12 +1683,12 @@ async def chat_completion(
|
||||
# Asynchronous Chat Processing
|
||||
task_id, _ = await create_task(
|
||||
request.app.state.redis,
|
||||
process_chat(request, form_data, user, metadata, model),
|
||||
process_chat(request, form_data, user, metadata, extra_params),
|
||||
id=metadata["chat_id"],
|
||||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
else:
|
||||
return await process_chat(request, form_data, user, metadata, model)
|
||||
return await process_chat(request, form_data, user, metadata, extra_params)
|
||||
|
||||
|
||||
# Alias for chat_completion (Legacy)
|
||||
@@ -1659,25 +1696,6 @@ generate_chat_completions = chat_completion
|
||||
generate_chat_completion = chat_completion
|
||||
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
async def chat_completed(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
model_item = form_data.pop("model_item", {})
|
||||
|
||||
if model_item.get("direct", False):
|
||||
request.state.direct = True
|
||||
request.state.model = model_item
|
||||
|
||||
return await chat_completed_handler(request, form_data, user)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/chat/actions/{action_id}")
|
||||
async def chat_action(
|
||||
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
|
||||
Reference in New Issue
Block a user