enh: ENABLE_MODEL_LIST_CACHE

This commit is contained in:
Timothy Jaeryang Baek
2025-06-28 15:12:31 +04:00
parent 2b88f66762
commit 1a52585769
5 changed files with 105 additions and 27 deletions

View File

@@ -36,7 +36,6 @@ from fastapi import (
applications,
BackgroundTasks,
)
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from starlette.datastructures import Headers
from open_webui.utils import logger
@@ -116,6 +116,8 @@ from open_webui.config import (
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
# Model list
ENABLE_MODEL_LIST_CACHE,
# Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE,
# Tool Server Configs
@@ -534,6 +536,27 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup())
if app.state.config.ENABLE_MODEL_LIST_CACHE:
get_all_models(
Request(
# Creating a mock request object to pass to get_all_models
{
"type": "http",
"asgi.version": "3.0",
"asgi.spec_version": "2.0",
"method": "GET",
"path": "/internal",
"query_string": b"",
"headers": Headers({}).raw,
"client": ("127.0.0.1", 12345),
"server": ("127.0.0.1", 80),
"scheme": "http",
"app": app,
}
),
None,
)
yield
if hasattr(app.state, "redis_task_command_listener"):
@@ -616,6 +639,14 @@ app.state.TOOL_SERVERS = []
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
########################################
#
# MODEL LIST
#
########################################
app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
########################################
#
# WEBUI
@@ -1191,7 +1222,9 @@ if audit_level != AuditLevel.NONE:
@app.get("/api/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
async def get_models(
request: Request, refresh: bool = False, user=Depends(get_verified_user)
):
def get_filtered_models(models, user):
filtered_models = []
for model in models:
@@ -1215,7 +1248,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
all_models = await get_all_models(request, user=user)
if request.app.state.MODELS and (
request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
):
all_models = list(request.app.state.MODELS.values())
else:
all_models = await get_all_models(request, user=user)
models = []
for model in all_models: