refac: chat requests

This commit is contained in:
Timothy Jaeryang Baek
2024-12-19 01:00:32 -08:00
parent ea0d507e23
commit 2be9e55545
11 changed files with 752 additions and 424 deletions

View File

@@ -30,7 +30,9 @@ from fastapi import (
UploadFile,
status,
applications,
BackgroundTasks,
)
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
@@ -295,6 +297,7 @@ from open_webui.utils.auth import (
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
if SAFE_MODE:
print("SAFE MODE ENABLED")
@@ -822,11 +825,11 @@ async def chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: bool = False,
):
if not request.app.state.MODELS:
await get_all_models(request)
tasks = form_data.pop("background_tasks", None)
try:
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
@@ -834,13 +837,14 @@ async def chat_completion(
model = request.app.state.MODELS[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
metadata = {
"user_id": user.id,
"chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None),
@@ -859,10 +863,10 @@ async def chat_completion(
)
try:
response = await chat_completion_handler(
request, form_data, user, bypass_filter
response = await chat_completion_handler(request, form_data, user)
return await process_chat_response(
request, response, user, events, metadata, tasks
)
return await process_chat_response(response, events, metadata)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -901,6 +905,20 @@ async def chat_action(
)
@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
try:
result = await stop_task(task_id) # Use the function from tasks.py
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
return {"tasks": list_tasks()} # Use the function from tasks.py
##################################
#
# Config Endpoints