feat(sqlalchemy): remove session reference from router

This commit is contained in:
Jonathan Rohde
2024-06-21 14:58:57 +02:00
parent df09d0830a
commit bee835cb65
34 changed files with 1231 additions and 1211 deletions

View File

@@ -57,7 +57,7 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import get_db, SessionLocal
from apps.webui.internal.db import get_session, SessionLocal
from pydantic import BaseModel
@@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
@@ -800,9 +799,7 @@ app.add_middleware(
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
db = SessionLocal()
await get_all_models(db)
db.commit()
await get_all_models()
else:
pass
@@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models(db: Session):
async def get_all_models():
pipe_models = []
openai_models = []
ollama_models = []
pipe_models = await get_pipe_models(db)
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
@@ -863,7 +860,7 @@ async def get_all_models(db: Session):
models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models(db)
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id == None:
for model in models:
@@ -903,8 +900,8 @@ async def get_all_models(db: Session):
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
models = await get_all_models(db)
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
# Filter out filter pipelines
models = [
@@ -1608,9 +1605,8 @@ async def get_pipeline_valves(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models(db)
models = await get_all_models()
r = None
try:
@@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models(db)
models = await get_all_models()
r = None
try:
@@ -1690,9 +1685,8 @@ async def update_pipeline_valves(
pipeline_id: str,
form_data: dict,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models(db)
models = await get_all_models()
r = None
try:
@@ -2040,8 +2034,9 @@ async def healthcheck():
@app.get("/health/db")
async def healthcheck_with_db(db: Session = Depends(get_db)):
result = db.execute(text("SELECT 1;")).all()
async def healthcheck_with_db():
with get_session() as db:
result = db.execute(text("SELECT 1;")).all()
return {"status": True}