This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 19:52:46 -08:00
parent 772f5ccd60
commit fe5519e0a2
5 changed files with 236 additions and 212 deletions

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from starlette.responses import FileResponse
from typing import Optional
import logging
@@ -16,6 +17,9 @@ from open_webui.utils.task import (
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
@@ -121,9 +125,7 @@ async def update_task_config(
async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -191,7 +193,7 @@ Artificial Intelligence in Healthcare
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
payload = process_pipeline_inlet_filter(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
@@ -220,8 +222,7 @@ async def generate_chat_tags(
content={"detail": "Tags generation is disabled"},
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -281,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
payload = process_pipeline_inlet_filter(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
@@ -318,8 +319,7 @@ async def generate_queries(
detail=f"Query generation is disabled",
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -363,7 +363,7 @@ async def generate_queries(
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
payload = process_pipeline_inlet_filter(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
@@ -405,8 +405,7 @@ async def generate_autocompletion(
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -450,7 +449,7 @@ async def generate_autocompletion(
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
payload = process_pipeline_inlet_filter(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
@@ -473,8 +472,7 @@ async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -525,7 +523,7 @@ Message: """{{prompt}}"""
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
payload = process_pipeline_inlet_filter(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
@@ -548,10 +546,9 @@ async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -593,7 +590,7 @@ Responses from models: {{responses}}"""
}
try:
payload = filter_pipeline(payload, user, models)
payload = process_pipeline_inlet_filter(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(