mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
Merge remote-tracking branch 'upstream/dev' into feat/oauth
This commit is contained in:
492
backend/main.py
492
backend/main.py
@@ -16,6 +16,7 @@ import mimetypes
|
||||
|
||||
from fastapi import FastAPI, Request, Depends, status
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import HTTPException
|
||||
from fastapi.middleware.wsgi import WSGIMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -24,6 +25,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import StreamingResponse, Response, RedirectResponse
|
||||
|
||||
|
||||
from apps.socket.main import app as socket_app
|
||||
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
|
||||
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
|
||||
|
||||
@@ -43,6 +46,8 @@ from utils.misc import parse_duration
|
||||
from utils.utils import (
|
||||
get_admin_user,
|
||||
get_verified_user,
|
||||
get_current_user,
|
||||
get_http_authorization_cred,
|
||||
get_password_hash,
|
||||
create_token,
|
||||
)
|
||||
@@ -136,7 +141,6 @@ app.state.MODELS = {}
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
|
||||
# Custom middleware to add security headers
|
||||
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
# async def dispatch(self, request: Request, call_next):
|
||||
@@ -154,7 +158,8 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
return_citations = False
|
||||
|
||||
if request.method == "POST" and (
|
||||
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||
"/ollama/api/chat" in request.url.path
|
||||
or "/chat/completions" in request.url.path
|
||||
):
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
@@ -239,6 +244,124 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
app.add_middleware(RAGMiddleware)
|
||||
|
||||
|
||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.method == "POST" and (
|
||||
"/ollama/api/chat" in request.url.path
|
||||
or "/chat/completions" in request.url.path
|
||||
):
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
# Read the original request body
|
||||
body = await request.body()
|
||||
# Decode body to string
|
||||
body_str = body.decode("utf-8")
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
model_id = data["model"]
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
|
||||
user = None
|
||||
if len(sorted_filters) > 0:
|
||||
try:
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(
|
||||
request.headers.get("Authorization")
|
||||
)
|
||||
)
|
||||
user = {"id": user.id, "name": user.name, "role": user.role}
|
||||
except:
|
||||
pass
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return JSONResponse(
|
||||
status_code=r.status_code,
|
||||
content=res,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
if "pipeline" not in app.state.MODELS[model_id]:
|
||||
if "chat_id" in data:
|
||||
del data["chat_id"]
|
||||
|
||||
if "title" in data:
|
||||
del data["title"]
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
# Replace the request body with the modified one
|
||||
request._body = modified_body_bytes
|
||||
# Set custom header to ensure content-length matches new body length
|
||||
request.headers.__dict__["_list"] = [
|
||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||
*[
|
||||
(k, v)
|
||||
for k, v in request.headers.raw
|
||||
if k.lower() != b"content-length"
|
||||
],
|
||||
]
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
async def _receive(self, body: bytes):
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
|
||||
app.add_middleware(PipelineMiddleware)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
@@ -271,6 +394,9 @@ async def update_embedding_function(request: Request, call_next):
|
||||
return response
|
||||
|
||||
|
||||
app.mount("/ws", socket_app)
|
||||
|
||||
|
||||
app.mount("/ollama", ollama_app)
|
||||
app.mount("/openai", openai_app)
|
||||
|
||||
@@ -351,6 +477,14 @@ async def get_all_models():
|
||||
@app.get("/api/models")
|
||||
async def get_models(user=Depends(get_verified_user)):
|
||||
models = await get_all_models()
|
||||
|
||||
# Filter out filter pipelines
|
||||
models = [
|
||||
model
|
||||
for model in models
|
||||
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
|
||||
]
|
||||
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models = list(
|
||||
@@ -364,6 +498,339 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
return {"data": models}
|
||||
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
|
||||
print(model_id)
|
||||
|
||||
if model_id in app.state.MODELS:
|
||||
model = app.state.MODELS[model_id]
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": {"id": user.id, "name": user.name, "role": user.role},
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return JSONResponse(
|
||||
status_code=r.status_code,
|
||||
content=res,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@app.get("/api/pipelines/list")
|
||||
async def get_pipelines_list(user=Depends(get_admin_user)):
|
||||
responses = await get_openai_models(raw=True)
|
||||
|
||||
print(responses)
|
||||
urlIdxs = [
|
||||
idx
|
||||
for idx, response in enumerate(responses)
|
||||
if response != None and "pipelines" in response
|
||||
]
|
||||
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx],
|
||||
"idx": urlIdx,
|
||||
}
|
||||
for urlIdx in urlIdxs
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class AddPipelineForm(BaseModel):
|
||||
url: str
|
||||
urlIdx: int
|
||||
|
||||
|
||||
@app.post("/api/pipelines/add")
|
||||
async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
|
||||
|
||||
r = None
|
||||
try:
|
||||
urlIdx = form_data.urlIdx
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
detail = "Pipeline not found"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
class DeletePipelineForm(BaseModel):
|
||||
id: str
|
||||
urlIdx: int
|
||||
|
||||
|
||||
@app.delete("/api/pipelines/delete")
|
||||
async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
|
||||
|
||||
r = None
|
||||
try:
|
||||
urlIdx = form_data.urlIdx
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.delete(
|
||||
f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
detail = "Pipeline not found"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/pipelines")
|
||||
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
|
||||
r = None
|
||||
try:
|
||||
urlIdx
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.get(f"{url}/pipelines", headers=headers)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
detail = "Pipeline not found"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/pipelines/{pipeline_id}/valves")
|
||||
async def get_pipeline_valves(
|
||||
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
models = await get_all_models()
|
||||
r = None
|
||||
try:
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
detail = "Pipeline not found"
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
||||
async def get_pipeline_valves_spec(
|
||||
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
models = await get_all_models()
|
||||
|
||||
r = None
|
||||
try:
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
detail = "Pipeline not found"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/pipelines/{pipeline_id}/valves/update")
|
||||
async def update_pipeline_valves(
|
||||
urlIdx: Optional[int],
|
||||
pipeline_id: str,
|
||||
form_data: dict,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
models = await get_all_models()
|
||||
|
||||
r = None
|
||||
try:
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{pipeline_id}/valves/update",
|
||||
headers=headers,
|
||||
json={**form_data},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
detail = "Pipeline not found"
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/config")
|
||||
async def get_app_config():
|
||||
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
|
||||
@@ -384,9 +851,10 @@ async def get_app_config():
|
||||
"auth": WEBUI_AUTH,
|
||||
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
||||
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
|
||||
"enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"enable_image_generation": images_app.state.config.ENABLED,
|
||||
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
||||
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
||||
},
|
||||
"oauth": {
|
||||
"providers": {
|
||||
@@ -438,23 +906,7 @@ class UrlForm(BaseModel):
|
||||
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
|
||||
app.state.config.WEBHOOK_URL = form_data.url
|
||||
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
|
||||
|
||||
return {
|
||||
"url": app.state.config.WEBHOOK_URL,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/community_sharing", response_model=bool)
|
||||
async def get_community_sharing_status(request: Request, user=Depends(get_admin_user)):
|
||||
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
|
||||
|
||||
|
||||
@app.get("/api/community_sharing/toggle", response_model=bool)
|
||||
async def toggle_community_sharing(request: Request, user=Depends(get_admin_user)):
|
||||
webui_app.state.config.ENABLE_COMMUNITY_SHARING = (
|
||||
not webui_app.state.config.ENABLE_COMMUNITY_SHARING
|
||||
)
|
||||
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
|
||||
return {"url": app.state.config.WEBHOOK_URL}
|
||||
|
||||
|
||||
@app.get("/api/version")
|
||||
|
||||
Reference in New Issue
Block a user