mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 11:27:46 +01:00
refac: chat requests
This commit is contained in:
@@ -30,7 +30,9 @@ from fastapi import (
|
||||
UploadFile,
|
||||
status,
|
||||
applications,
|
||||
BackgroundTasks,
|
||||
)
|
||||
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -295,6 +297,7 @@ from open_webui.utils.auth import (
|
||||
from open_webui.utils.oauth import oauth_manager
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
@@ -822,11 +825,11 @@ async def chat_completion(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user=Depends(get_verified_user),
|
||||
bypass_filter: bool = False,
|
||||
):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
|
||||
tasks = form_data.pop("background_tasks", None)
|
||||
try:
|
||||
model_id = form_data.get("model", None)
|
||||
if model_id not in request.app.state.MODELS:
|
||||
@@ -834,13 +837,14 @@ async def chat_completion(
|
||||
model = request.app.state.MODELS[model_id]
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
metadata = {
|
||||
"user_id": user.id,
|
||||
"chat_id": form_data.pop("chat_id", None),
|
||||
"message_id": form_data.pop("id", None),
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
@@ -859,10 +863,10 @@ async def chat_completion(
|
||||
)
|
||||
|
||||
try:
|
||||
response = await chat_completion_handler(
|
||||
request, form_data, user, bypass_filter
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
return await process_chat_response(
|
||||
request, response, user, events, metadata, tasks
|
||||
)
|
||||
return await process_chat_response(response, events, metadata)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -901,6 +905,20 @@ async def chat_action(
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/tasks/stop/{task_id}")
|
||||
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
|
||||
try:
|
||||
result = await stop_task(task_id) # Use the function from tasks.py
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/api/tasks")
|
||||
async def list_tasks_endpoint(user=Depends(get_verified_user)):
|
||||
return {"tasks": list_tasks()} # Use the function from tasks.py
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Config Endpoints
|
||||
|
||||
Reference in New Issue
Block a user