feat(sqlalchemy): Replace peewee with sqlalchemy

This commit is contained in:
Jonathan Rohde
2024-06-18 15:03:31 +02:00
parent 8dac2a2140
commit df09d0830a
47 changed files with 2580 additions and 1003 deletions

View File

@@ -1,5 +1,6 @@
import base64
import uuid
import subprocess
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
@@ -27,6 +28,8 @@ from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
@@ -54,6 +57,7 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import get_db, SessionLocal
from pydantic import BaseModel
@@ -124,6 +128,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
BACKEND_DIR,
DATABASE_URL,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
@@ -166,8 +172,19 @@ https://github.com/open-webui/open-webui
)
def run_migrations():
from alembic.config import Config
from alembic import command
alembic_cfg = Config(f"{BACKEND_DIR}/alembic.ini")
alembic_cfg.set_main_option("sqlalchemy.url", DATABASE_URL)
alembic_cfg.set_main_option("script_location", f"{BACKEND_DIR}/migrations")
command.upgrade(alembic_cfg, "head")
@asynccontextmanager
async def lifespan(app: FastAPI):
run_migrations()
yield
@@ -393,6 +410,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
@@ -736,6 +754,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
)
try:
@@ -781,7 +800,9 @@ app.add_middleware(
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
db = SessionLocal()
await get_all_models(db)
db.commit()
else:
pass
@@ -815,12 +836,12 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models():
async def get_all_models(db: Session):
pipe_models = []
openai_models = []
ollama_models = []
pipe_models = await get_pipe_models()
pipe_models = await get_pipe_models(db)
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
@@ -842,7 +863,7 @@ async def get_all_models():
models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models()
custom_models = Models.get_all_models(db)
for custom_model in custom_models:
if custom_model.base_model_id == None:
for model in models:
@@ -882,8 +903,8 @@ async def get_all_models():
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
models = await get_all_models(db)
# Filter out filter pipelines
models = [
@@ -1584,9 +1605,12 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models()
models = await get_all_models(db)
r = None
try:
@@ -1622,9 +1646,12 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models()
models = await get_all_models(db)
r = None
try:
@@ -1663,8 +1690,9 @@ async def update_pipeline_valves(
pipeline_id: str,
form_data: dict,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models()
models = await get_all_models(db)
r = None
try:
@@ -2011,6 +2039,12 @@ async def healthcheck():
return {"status": True}
@app.get("/health/db")
async def healthcheck_with_db(db: Session = Depends(get_db)):
result = db.execute(text("SELECT 1;")).all()
return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")