mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
enh: ENABLE_MODEL_LIST_CACHE
This commit is contained in:
@@ -36,7 +36,6 @@ from fastapi import (
|
||||
applications,
|
||||
BackgroundTasks,
|
||||
)
|
||||
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
|
||||
from open_webui.utils import logger
|
||||
@@ -116,6 +116,8 @@ from open_webui.config import (
|
||||
OPENAI_API_CONFIGS,
|
||||
# Direct Connections
|
||||
ENABLE_DIRECT_CONNECTIONS,
|
||||
# Model list
|
||||
ENABLE_MODEL_LIST_CACHE,
|
||||
# Thread pool size for FastAPI/AnyIO
|
||||
THREAD_POOL_SIZE,
|
||||
# Tool Server Configs
|
||||
@@ -534,6 +536,27 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
asyncio.create_task(periodic_usage_pool_cleanup())
|
||||
|
||||
if app.state.config.ENABLE_MODEL_LIST_CACHE:
|
||||
get_all_models(
|
||||
Request(
|
||||
# Creating a mock request object to pass to get_all_models
|
||||
{
|
||||
"type": "http",
|
||||
"asgi.version": "3.0",
|
||||
"asgi.spec_version": "2.0",
|
||||
"method": "GET",
|
||||
"path": "/internal",
|
||||
"query_string": b"",
|
||||
"headers": Headers({}).raw,
|
||||
"client": ("127.0.0.1", 12345),
|
||||
"server": ("127.0.0.1", 80),
|
||||
"scheme": "http",
|
||||
"app": app,
|
||||
}
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
if hasattr(app.state, "redis_task_command_listener"):
|
||||
@@ -616,6 +639,14 @@ app.state.TOOL_SERVERS = []
|
||||
|
||||
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
|
||||
|
||||
########################################
|
||||
#
|
||||
# MODEL LIST
|
||||
#
|
||||
########################################
|
||||
|
||||
app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
|
||||
|
||||
########################################
|
||||
#
|
||||
# WEBUI
|
||||
@@ -1191,7 +1222,9 @@ if audit_level != AuditLevel.NONE:
|
||||
|
||||
|
||||
@app.get("/api/models")
|
||||
async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
async def get_models(
|
||||
request: Request, refresh: bool = False, user=Depends(get_verified_user)
|
||||
):
|
||||
def get_filtered_models(models, user):
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
@@ -1215,7 +1248,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
return filtered_models
|
||||
|
||||
all_models = await get_all_models(request, user=user)
|
||||
if request.app.state.MODELS and (
|
||||
request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
|
||||
):
|
||||
all_models = list(request.app.state.MODELS.values())
|
||||
else:
|
||||
all_models = await get_all_models(request, user=user)
|
||||
|
||||
models = []
|
||||
for model in all_models:
|
||||
|
||||
Reference in New Issue
Block a user