mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 20:07:49 +01:00
feat(sqlalchemy): Replace peewee with sqlalchemy
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import uuid
|
||||
import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
@@ -27,6 +28,8 @@ from fastapi.responses import JSONResponse
|
||||
from fastapi import HTTPException
|
||||
from fastapi.middleware.wsgi import WSGIMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
@@ -54,6 +57,7 @@ from apps.webui.main import (
|
||||
get_pipe_models,
|
||||
generate_function_chat_completion,
|
||||
)
|
||||
from apps.webui.internal.db import get_db, SessionLocal
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -124,6 +128,8 @@ from config import (
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
AppConfig,
|
||||
BACKEND_DIR,
|
||||
DATABASE_URL,
|
||||
)
|
||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from utils.webhook import post_webhook
|
||||
@@ -166,8 +172,19 @@ https://github.com/open-webui/open-webui
|
||||
)
|
||||
|
||||
|
||||
def run_migrations():
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
|
||||
alembic_cfg = Config(f"{BACKEND_DIR}/alembic.ini")
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", DATABASE_URL)
|
||||
alembic_cfg.set_main_option("script_location", f"{BACKEND_DIR}/migrations")
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
run_migrations()
|
||||
yield
|
||||
|
||||
|
||||
@@ -393,6 +410,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
user = get_current_user(
|
||||
request,
|
||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||
SessionLocal(),
|
||||
)
|
||||
# Flag to skip RAG completions if file_handler is present in tools/functions
|
||||
skip_files = False
|
||||
@@ -736,6 +754,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
user = get_current_user(
|
||||
request,
|
||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||
SessionLocal(),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -781,7 +800,9 @@ app.add_middleware(
|
||||
@app.middleware("http")
|
||||
async def check_url(request: Request, call_next):
|
||||
if len(app.state.MODELS) == 0:
|
||||
await get_all_models()
|
||||
db = SessionLocal()
|
||||
await get_all_models(db)
|
||||
db.commit()
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -815,12 +836,12 @@ app.mount("/api/v1", webui_app)
|
||||
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
|
||||
|
||||
|
||||
async def get_all_models():
|
||||
async def get_all_models(db: Session):
|
||||
pipe_models = []
|
||||
openai_models = []
|
||||
ollama_models = []
|
||||
|
||||
pipe_models = await get_pipe_models()
|
||||
pipe_models = await get_pipe_models(db)
|
||||
|
||||
if app.state.config.ENABLE_OPENAI_API:
|
||||
openai_models = await get_openai_models()
|
||||
@@ -842,7 +863,7 @@ async def get_all_models():
|
||||
|
||||
models = pipe_models + openai_models + ollama_models
|
||||
|
||||
custom_models = Models.get_all_models()
|
||||
custom_models = Models.get_all_models(db)
|
||||
for custom_model in custom_models:
|
||||
if custom_model.base_model_id == None:
|
||||
for model in models:
|
||||
@@ -882,8 +903,8 @@ async def get_all_models():
|
||||
|
||||
|
||||
@app.get("/api/models")
|
||||
async def get_models(user=Depends(get_verified_user)):
|
||||
models = await get_all_models()
|
||||
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
|
||||
models = await get_all_models(db)
|
||||
|
||||
# Filter out filter pipelines
|
||||
models = [
|
||||
@@ -1584,9 +1605,12 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
|
||||
|
||||
@app.get("/api/pipelines/{pipeline_id}/valves")
|
||||
async def get_pipeline_valves(
|
||||
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
|
||||
urlIdx: Optional[int],
|
||||
pipeline_id: str,
|
||||
user=Depends(get_admin_user),
|
||||
db=Depends(get_db),
|
||||
):
|
||||
models = await get_all_models()
|
||||
models = await get_all_models(db)
|
||||
r = None
|
||||
try:
|
||||
|
||||
@@ -1622,9 +1646,12 @@ async def get_pipeline_valves(
|
||||
|
||||
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
||||
async def get_pipeline_valves_spec(
|
||||
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
|
||||
urlIdx: Optional[int],
|
||||
pipeline_id: str,
|
||||
user=Depends(get_admin_user),
|
||||
db=Depends(get_db),
|
||||
):
|
||||
models = await get_all_models()
|
||||
models = await get_all_models(db)
|
||||
|
||||
r = None
|
||||
try:
|
||||
@@ -1663,8 +1690,9 @@ async def update_pipeline_valves(
|
||||
pipeline_id: str,
|
||||
form_data: dict,
|
||||
user=Depends(get_admin_user),
|
||||
db=Depends(get_db),
|
||||
):
|
||||
models = await get_all_models()
|
||||
models = await get_all_models(db)
|
||||
|
||||
r = None
|
||||
try:
|
||||
@@ -2011,6 +2039,12 @@ async def healthcheck():
|
||||
return {"status": True}
|
||||
|
||||
|
||||
@app.get("/health/db")
|
||||
async def healthcheck_with_db(db: Session = Depends(get_db)):
|
||||
result = db.execute(text("SELECT 1;")).all()
|
||||
return {"status": True}
|
||||
|
||||
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user