Merge pull request #3327 from jonathan-rohde/feat/sqlalchemy-instead-of-peewee

BREAKING CHANGE/sqlalchemy instead of peewee
This commit is contained in:
Timothy Jaeryang Baek
2024-07-02 16:40:13 -07:00
committed by GitHub
60 changed files with 2217 additions and 2106 deletions

View File

@@ -1,5 +1,6 @@
import base64
import uuid
import subprocess
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
@@ -27,6 +28,7 @@ from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
@@ -54,6 +56,7 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import Session, SessionLocal
from pydantic import BaseModel
@@ -125,6 +128,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
BACKEND_DIR,
DATABASE_URL,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
@@ -167,8 +172,19 @@ https://github.com/open-webui/open-webui
)
def run_migrations():
env = os.environ.copy()
env["DATABASE_URL"] = DATABASE_URL
migration_task = subprocess.run(
["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env
)
if migration_task.returncode > 0:
raise ValueError("Error running migrations")
@asynccontextmanager
async def lifespan(app: FastAPI):
run_migrations()
yield
@@ -902,6 +918,14 @@ app.add_middleware(
)
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
response = await call_next(request)
log.debug("Commit session after request")
Session.commit()
return response
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
@@ -1743,7 +1767,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
r = None
@@ -1781,7 +1807,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
@@ -2168,6 +2196,12 @@ async def healthcheck():
return {"status": True}
@app.get("/health/db")
async def healthcheck_with_db():
Session.execute(text("SELECT 1;")).all()
return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")