mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 03:47:49 +01:00
Merge branch 'open-webui:main' into main
This commit is contained in:
@@ -21,7 +21,7 @@ from open_webui.env import (
|
||||
WEBUI_NAME,
|
||||
log,
|
||||
DATABASE_URL,
|
||||
OFFLINE_MODE
|
||||
OFFLINE_MODE,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
||||
@@ -272,6 +272,18 @@ ENABLE_API_KEY = PersistentConfig(
|
||||
os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
|
||||
)
|
||||
|
||||
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = PersistentConfig(
|
||||
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS",
|
||||
"auth.api_key.endpoint_restrictions",
|
||||
os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False").lower() == "true",
|
||||
)
|
||||
|
||||
API_KEY_ALLOWED_ENDPOINTS = PersistentConfig(
|
||||
"API_KEY_ALLOWED_ENDPOINTS",
|
||||
"auth.api_key.allowed_endpoints",
|
||||
os.environ.get("API_KEY_ALLOWED_ENDPOINTS", ""),
|
||||
)
|
||||
|
||||
|
||||
JWT_EXPIRES_IN = PersistentConfig(
|
||||
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
|
||||
@@ -307,6 +319,7 @@ GOOGLE_CLIENT_SECRET = PersistentConfig(
|
||||
os.environ.get("GOOGLE_CLIENT_SECRET", ""),
|
||||
)
|
||||
|
||||
|
||||
GOOGLE_OAUTH_SCOPE = PersistentConfig(
|
||||
"GOOGLE_OAUTH_SCOPE",
|
||||
"oauth.google.scope",
|
||||
@@ -403,12 +416,24 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
|
||||
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
|
||||
)
|
||||
|
||||
OAUTH_GROUPS_CLAIM = PersistentConfig(
|
||||
"OAUTH_GROUPS_CLAIM",
|
||||
"oauth.oidc.group_claim",
|
||||
os.environ.get("OAUTH_GROUP_CLAIM", "groups"),
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
|
||||
"ENABLE_OAUTH_ROLE_MANAGEMENT",
|
||||
"oauth.enable_role_mapping",
|
||||
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
|
||||
"ENABLE_OAUTH_GROUP_MANAGEMENT",
|
||||
"oauth.enable_group_mapping",
|
||||
os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
OAUTH_ROLES_CLAIM = PersistentConfig(
|
||||
"OAUTH_ROLES_CLAIM",
|
||||
"oauth.roles_claim",
|
||||
@@ -696,6 +721,12 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
||||
# WEBUI
|
||||
####################################
|
||||
|
||||
|
||||
WEBUI_URL = PersistentConfig(
|
||||
"WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
|
||||
)
|
||||
|
||||
|
||||
ENABLE_SIGNUP = PersistentConfig(
|
||||
"ENABLE_SIGNUP",
|
||||
"ui.enable_signup",
|
||||
@@ -823,6 +854,12 @@ USER_PERMISSIONS = PersistentConfig(
|
||||
},
|
||||
)
|
||||
|
||||
ENABLE_CHANNELS = PersistentConfig(
|
||||
"ENABLE_CHANNELS",
|
||||
"channels.enable",
|
||||
os.environ.get("ENABLE_CHANNELS", "False").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
|
||||
"ENABLE_EVALUATION_ARENA_MODELS",
|
||||
@@ -1174,11 +1211,34 @@ if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
|
||||
raise ValueError(
|
||||
"Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
|
||||
)
|
||||
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
|
||||
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
||||
)
|
||||
|
||||
####################################
|
||||
# Information Retrieval (RAG)
|
||||
####################################
|
||||
|
||||
|
||||
# If configured, Google Drive will be available as an upload option.
|
||||
ENABLE_GOOGLE_DRIVE_INTEGRATION = PersistentConfig(
|
||||
"ENABLE_GOOGLE_DRIVE_INTEGRATION",
|
||||
"google_drive.enable",
|
||||
os.getenv("ENABLE_GOOGLE_DRIVE_INTEGRATION", "False").lower() == "true",
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_CLIENT_ID = PersistentConfig(
|
||||
"GOOGLE_DRIVE_CLIENT_ID",
|
||||
"google_drive.client_id",
|
||||
os.environ.get("GOOGLE_DRIVE_CLIENT_ID", ""),
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_API_KEY = PersistentConfig(
|
||||
"GOOGLE_DRIVE_API_KEY",
|
||||
"google_drive.api_key",
|
||||
os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
|
||||
)
|
||||
|
||||
# RAG Content Extraction
|
||||
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
|
||||
"CONTENT_EXTRACTION_ENGINE",
|
||||
@@ -1253,7 +1313,8 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
|
||||
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
|
||||
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
||||
not OFFLINE_MODE and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
||||
not OFFLINE_MODE
|
||||
and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
||||
@@ -1278,7 +1339,8 @@ if RAG_RERANKING_MODEL.value != "":
|
||||
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
|
||||
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
||||
not OFFLINE_MODE and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
||||
not OFFLINE_MODE
|
||||
and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
||||
@@ -1412,6 +1474,7 @@ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
SEARXNG_QUERY_URL = PersistentConfig(
|
||||
"SEARXNG_QUERY_URL",
|
||||
"rag.web.search.searxng_query_url",
|
||||
@@ -1587,6 +1650,12 @@ COMFYUI_BASE_URL = PersistentConfig(
|
||||
os.getenv("COMFYUI_BASE_URL", ""),
|
||||
)
|
||||
|
||||
COMFYUI_API_KEY = PersistentConfig(
|
||||
"COMFYUI_API_KEY",
|
||||
"image_generation.comfyui.api_key",
|
||||
os.getenv("COMFYUI_API_KEY", ""),
|
||||
)
|
||||
|
||||
COMFYUI_DEFAULT_WORKFLOW = """
|
||||
{
|
||||
"3": {
|
||||
@@ -1748,7 +1817,8 @@ WHISPER_MODEL = PersistentConfig(
|
||||
|
||||
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
||||
WHISPER_MODEL_AUTO_UPDATE = (
|
||||
not OFFLINE_MODE and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
not OFFLINE_MODE
|
||||
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -53,6 +53,11 @@ if USE_CUDA.lower() == "true":
|
||||
else:
|
||||
DEVICE_TYPE = "cpu"
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
DEVICE_TYPE = "mps"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
####################################
|
||||
# LOGGING
|
||||
@@ -103,8 +108,6 @@ WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||
if WEBUI_NAME != "Open WebUI":
|
||||
WEBUI_NAME += " (Open WebUI)"
|
||||
|
||||
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
|
||||
|
||||
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
||||
|
||||
|
||||
@@ -315,6 +318,11 @@ RESET_CONFIG_ON_START = (
|
||||
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
ENABLE_REALTIME_CHAT_SAVE = (
|
||||
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
|
||||
)
|
||||
|
||||
####################################
|
||||
# REDIS
|
||||
####################################
|
||||
@@ -396,3 +404,6 @@ else:
|
||||
####################################
|
||||
|
||||
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
|
||||
|
||||
if OFFLINE_MODE:
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
@@ -55,7 +55,7 @@ def handle_peewee_migration(DATABASE_URL):
|
||||
try:
|
||||
# Replace the postgresql:// with postgres:// to handle the peewee migration
|
||||
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
|
||||
migrate_dir = OPEN_WEBUI_DIR / "apps" / "webui" / "internal" / "migrations"
|
||||
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
|
||||
router = Router(db, logger=log, migrate_dir=migrate_dir)
|
||||
router.run()
|
||||
db.close()
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import Optional
|
||||
from aiocache import cached
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
@@ -27,7 +29,12 @@ from fastapi import (
|
||||
Request,
|
||||
UploadFile,
|
||||
status,
|
||||
applications,
|
||||
BackgroundTasks,
|
||||
)
|
||||
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -51,6 +58,7 @@ from open_webui.routers import (
|
||||
pipelines,
|
||||
tasks,
|
||||
auths,
|
||||
channels,
|
||||
chats,
|
||||
folders,
|
||||
configs,
|
||||
@@ -96,6 +104,7 @@ from open_webui.config import (
|
||||
AUTOMATIC1111_SAMPLER,
|
||||
AUTOMATIC1111_SCHEDULER,
|
||||
COMFYUI_BASE_URL,
|
||||
COMFYUI_API_KEY,
|
||||
COMFYUI_WORKFLOW,
|
||||
COMFYUI_WORKFLOW_NODES,
|
||||
ENABLE_IMAGE_GENERATION,
|
||||
@@ -171,10 +180,13 @@ from open_webui.config import (
|
||||
MOJEEK_SEARCH_API_KEY,
|
||||
GOOGLE_PSE_API_KEY,
|
||||
GOOGLE_PSE_ENGINE_ID,
|
||||
GOOGLE_DRIVE_CLIENT_ID,
|
||||
GOOGLE_DRIVE_API_KEY,
|
||||
ENABLE_RAG_HYBRID_SEARCH,
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
ENABLE_RAG_WEB_SEARCH,
|
||||
ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
UPLOAD_DIR,
|
||||
# WebUI
|
||||
WEBUI_AUTH,
|
||||
@@ -187,6 +199,9 @@ from open_webui.config import (
|
||||
ENABLE_SIGNUP,
|
||||
ENABLE_LOGIN_FORM,
|
||||
ENABLE_API_KEY,
|
||||
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||
API_KEY_ALLOWED_ENDPOINTS,
|
||||
ENABLE_CHANNELS,
|
||||
ENABLE_COMMUNITY_SHARING,
|
||||
ENABLE_MESSAGE_RATING,
|
||||
ENABLE_EVALUATION_ARENA_MODELS,
|
||||
@@ -226,6 +241,7 @@ from open_webui.config import (
|
||||
CORS_ALLOW_ORIGIN,
|
||||
DEFAULT_LOCALE,
|
||||
OAUTH_PROVIDERS,
|
||||
WEBUI_URL,
|
||||
# Admin
|
||||
ENABLE_ADMIN_CHAT_ACCESS,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
@@ -251,13 +267,13 @@ from open_webui.env import (
|
||||
SAFE_MODE,
|
||||
SRC_LOG_LEVELS,
|
||||
VERSION,
|
||||
WEBUI_URL,
|
||||
WEBUI_BUILD_HASH,
|
||||
WEBUI_SECRET_KEY,
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
RESET_CONFIG_ON_START,
|
||||
OFFLINE_MODE,
|
||||
@@ -285,6 +301,7 @@ from open_webui.utils.auth import (
|
||||
from open_webui.utils.oauth import oauth_manager
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
@@ -374,9 +391,15 @@ app.state.OPENAI_MODELS = {}
|
||||
#
|
||||
########################################
|
||||
|
||||
app.state.config.WEBUI_URL = WEBUI_URL
|
||||
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
|
||||
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
|
||||
|
||||
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
|
||||
app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
|
||||
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
|
||||
)
|
||||
app.state.config.API_KEY_ALLOWED_ENDPOINTS = API_KEY_ALLOWED_ENDPOINTS
|
||||
|
||||
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||
|
||||
@@ -393,6 +416,8 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
app.state.config.BANNERS = WEBUI_BANNERS
|
||||
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
|
||||
|
||||
|
||||
app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
|
||||
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
|
||||
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
|
||||
|
||||
@@ -477,6 +502,7 @@ app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
|
||||
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
|
||||
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
||||
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
|
||||
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
|
||||
@@ -504,6 +530,22 @@ app.state.rf = None
|
||||
app.state.YOUTUBE_LOADER_TRANSLATION = None
|
||||
|
||||
|
||||
try:
|
||||
app.state.ef = get_ef(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
|
||||
app.state.rf = get_rf(
|
||||
app.state.config.RAG_RERANKING_MODEL,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error updating models: {e}")
|
||||
pass
|
||||
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
@@ -521,21 +563,6 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
try:
|
||||
app.state.ef = get_ef(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
|
||||
app.state.rf = get_rf(
|
||||
app.state.config.RAG_RERANKING_MODEL,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error updating models: {e}")
|
||||
pass
|
||||
|
||||
|
||||
########################################
|
||||
#
|
||||
@@ -557,6 +584,7 @@ app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
|
||||
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
|
||||
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
|
||||
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
app.state.config.COMFYUI_API_KEY = COMFYUI_API_KEY
|
||||
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
|
||||
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
||||
|
||||
@@ -722,6 +750,8 @@ app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])
|
||||
app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
|
||||
|
||||
|
||||
app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"])
|
||||
app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
|
||||
|
||||
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
|
||||
@@ -810,11 +840,11 @@ async def chat_completion(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user=Depends(get_verified_user),
|
||||
bypass_filter: bool = False,
|
||||
):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
|
||||
tasks = form_data.pop("background_tasks", None)
|
||||
try:
|
||||
model_id = form_data.get("model", None)
|
||||
if model_id not in request.app.state.MODELS:
|
||||
@@ -822,13 +852,26 @@ async def chat_completion(
|
||||
model = request.app.state.MODELS[model_id]
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
form_data, events = await process_chat_payload(request, form_data, user, model)
|
||||
metadata = {
|
||||
"user_id": user.id,
|
||||
"chat_id": form_data.pop("chat_id", None),
|
||||
"message_id": form_data.pop("id", None),
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
"tool_ids": form_data.get("tool_ids", None),
|
||||
"files": form_data.get("files", None),
|
||||
"features": form_data.get("features", None),
|
||||
}
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
form_data, events = await process_chat_payload(
|
||||
request, form_data, metadata, user, model
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -836,10 +879,10 @@ async def chat_completion(
|
||||
)
|
||||
|
||||
try:
|
||||
response = await chat_completion_handler(
|
||||
request, form_data, user, bypass_filter
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
return await process_chat_response(
|
||||
request, response, form_data, user, events, metadata, tasks
|
||||
)
|
||||
return await process_chat_response(response, events)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -878,6 +921,20 @@ async def chat_action(
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/tasks/stop/{task_id}")
|
||||
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
|
||||
try:
|
||||
result = await stop_task(task_id) # Use the function from tasks.py
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/api/tasks")
|
||||
async def list_tasks_endpoint(user=Depends(get_verified_user)):
|
||||
return {"tasks": list_tasks()} # Use the function from tasks.py
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Config Endpoints
|
||||
@@ -925,9 +982,12 @@ async def get_app_config(request: Request):
|
||||
"enable_api_key": app.state.config.ENABLE_API_KEY,
|
||||
"enable_signup": app.state.config.ENABLE_SIGNUP,
|
||||
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
||||
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
|
||||
**(
|
||||
{
|
||||
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
||||
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
|
||||
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
|
||||
@@ -938,6 +998,10 @@ async def get_app_config(request: Request):
|
||||
else {}
|
||||
),
|
||||
},
|
||||
"google_drive": {
|
||||
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
|
||||
"api_key": GOOGLE_DRIVE_API_KEY.value,
|
||||
},
|
||||
**(
|
||||
{
|
||||
"default_models": app.state.config.DEFAULT_MODELS,
|
||||
@@ -1082,9 +1146,9 @@ async def get_opensearch_xml():
|
||||
<ShortName>{WEBUI_NAME}</ShortName>
|
||||
<Description>Search {WEBUI_NAME}</Description>
|
||||
<InputEncoding>UTF-8</InputEncoding>
|
||||
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/static/favicon.png</Image>
|
||||
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
|
||||
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
|
||||
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
|
||||
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>
|
||||
<moz:SearchForm>{app.state.config.WEBUI_URL}</moz:SearchForm>
|
||||
</OpenSearchDescription>
|
||||
"""
|
||||
return Response(content=xml_content, media_type="application/xml")
|
||||
@@ -1104,6 +1168,19 @@ async def healthcheck_with_db():
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
|
||||
|
||||
|
||||
def swagger_ui_html(*args, **kwargs):
|
||||
return get_swagger_ui_html(
|
||||
*args,
|
||||
**kwargs,
|
||||
swagger_js_url="/static/swagger-ui/swagger-ui-bundle.js",
|
||||
swagger_css_url="/static/swagger-ui/swagger-ui.css",
|
||||
swagger_favicon_url="/static/swagger-ui/favicon.png",
|
||||
)
|
||||
|
||||
|
||||
applications.get_swagger_ui_html = swagger_ui_html
|
||||
|
||||
if os.path.exists(FRONTEND_BUILD_DIR):
|
||||
mimetypes.add_type("text/javascript", ".js")
|
||||
app.mount(
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Update message & channel tables
|
||||
|
||||
Revision ID: 3781e22d8b01
|
||||
Revises: 7826ab40b532
|
||||
Create Date: 2024-12-30 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "3781e22d8b01"
|
||||
down_revision = "7826ab40b532"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add 'type' column to the 'channel' table
|
||||
op.add_column(
|
||||
"channel",
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Text(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Add 'parent_id' column to the 'message' table for threads
|
||||
op.add_column(
|
||||
"message",
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"message_reaction",
|
||||
sa.Column(
|
||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
||||
), # Unique reaction ID
|
||||
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted
|
||||
sa.Column(
|
||||
"message_id", sa.Text(), nullable=False
|
||||
), # Message that was reacted to
|
||||
sa.Column(
|
||||
"name", sa.Text(), nullable=False
|
||||
), # Reaction name (e.g. "thumbs_up")
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=True
|
||||
), # Timestamp of when the reaction was added
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"channel_member",
|
||||
sa.Column(
|
||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
||||
), # Record ID for the membership row
|
||||
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel
|
||||
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=True
|
||||
), # Timestamp of when the user joined the channel
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Revert 'type' column addition to the 'channel' table
|
||||
op.drop_column("channel", "type")
|
||||
op.drop_column("message", "parent_id")
|
||||
op.drop_table("message_reaction")
|
||||
op.drop_table("channel_member")
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Add channel table
|
||||
|
||||
Revision ID: 57c599a3cb57
|
||||
Revises: 922e7a387820
|
||||
Create Date: 2024-12-22 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "57c599a3cb57"
|
||||
down_revision = "922e7a387820"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"channel",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("name", sa.Text()),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"message",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("channel_id", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text()),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("channel")
|
||||
|
||||
op.drop_table("message")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Update file table
|
||||
|
||||
Revision ID: 7826ab40b532
|
||||
Revises: 57c599a3cb57
|
||||
Create Date: 2024-12-23 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "7826ab40b532"
|
||||
down_revision = "57c599a3cb57"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"file",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("file", "access_control")
|
||||
136
backend/open_webui/models/channels.py
Normal file
136
backend/open_webui/models/channels.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
# Channel DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Channel(Base):
|
||||
__tablename__ = "channel"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
type = Column(Text, nullable=True)
|
||||
|
||||
name = Column(Text)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
access_control = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class ChannelModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
type: Optional[str] = None
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class ChannelForm(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
|
||||
class ChannelTable:
|
||||
def insert_new_channel(
|
||||
self, type: Optional[str], form_data: ChannelForm, user_id: str
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = ChannelModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"type": type,
|
||||
"name": form_data.name.lower(),
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time_ns()),
|
||||
"updated_at": int(time.time_ns()),
|
||||
}
|
||||
)
|
||||
|
||||
new_channel = Channel(**channel.model_dump())
|
||||
|
||||
db.add(new_channel)
|
||||
db.commit()
|
||||
return channel
|
||||
|
||||
def get_channels(self) -> list[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channels = db.query(Channel).all()
|
||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||
|
||||
def get_channels_by_user_id(
|
||||
self, user_id: str, permission: str = "read"
|
||||
) -> list[ChannelModel]:
|
||||
channels = self.get_channels()
|
||||
return [
|
||||
channel
|
||||
for channel in channels
|
||||
if channel.user_id == user_id
|
||||
or has_access(user_id, permission, channel.access_control)
|
||||
]
|
||||
|
||||
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
return ChannelModel.model_validate(channel) if channel else None
|
||||
|
||||
def update_channel_by_id(
|
||||
self, id: str, form_data: ChannelForm
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
if not channel:
|
||||
return None
|
||||
|
||||
channel.name = form_data.name
|
||||
channel.data = form_data.data
|
||||
channel.meta = form_data.meta
|
||||
channel.access_control = form_data.access_control
|
||||
channel.updated_at = int(time.time_ns())
|
||||
|
||||
db.commit()
|
||||
return ChannelModel.model_validate(channel) if channel else None
|
||||
|
||||
def delete_channel_by_id(self, id: str):
|
||||
with get_db() as db:
|
||||
db.query(Channel).filter(Channel.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
Channels = ChannelTable()
|
||||
@@ -168,6 +168,100 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
chat = chat.chat
|
||||
chat["title"] = title
|
||||
|
||||
return self.update_chat_by_id(id, chat)
|
||||
|
||||
def update_chat_tags_by_id(
|
||||
self, id: str, tags: list[str], user
|
||||
) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
self.delete_all_tags_by_id_and_user_id(id, user.id)
|
||||
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
|
||||
for tag_name in tags:
|
||||
if tag_name.lower() == "none":
|
||||
continue
|
||||
|
||||
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
|
||||
return self.get_chat_by_id(id)
|
||||
|
||||
def get_chat_title_by_id(self, id: str) -> Optional[str]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
return chat.chat.get("title", "New Chat")
|
||||
|
||||
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
return chat.chat.get("history", {}).get("messages", {}) or {}
|
||||
|
||||
def get_message_by_id_and_message_id(
|
||||
self, id: str, message_id: str
|
||||
) -> Optional[dict]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
return chat.chat.get("history", {}).get("messages", {}).get(message_id, {})
|
||||
|
||||
def upsert_message_to_chat_by_id_and_message_id(
|
||||
self, id: str, message_id: str, message: dict
|
||||
) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
chat = chat.chat
|
||||
history = chat.get("history", {})
|
||||
|
||||
if message_id in history.get("messages", {}):
|
||||
history["messages"][message_id] = {
|
||||
**history["messages"][message_id],
|
||||
**message,
|
||||
}
|
||||
else:
|
||||
history["messages"][message_id] = message
|
||||
|
||||
history["currentId"] = message_id
|
||||
|
||||
chat["history"] = history
|
||||
return self.update_chat_by_id(id, chat)
|
||||
|
||||
def add_message_status_to_chat_by_id_and_message_id(
|
||||
self, id: str, message_id: str, status: dict
|
||||
) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
chat = chat.chat
|
||||
history = chat.get("history", {})
|
||||
|
||||
if message_id in history.get("messages", {}):
|
||||
status_history = history["messages"][message_id].get("statusHistory", [])
|
||||
status_history.append(status)
|
||||
history["messages"][message_id]["statusHistory"] = status_history
|
||||
|
||||
chat["history"] = history
|
||||
return self.update_chat_by_id(id, chat)
|
||||
|
||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
# Get the existing chat to share
|
||||
@@ -375,6 +469,8 @@ class ChatTable:
|
||||
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# it is possible that the shared link was deleted. hence,
|
||||
# we check if the chat is still shared by checkng if a chat with the share_id exists
|
||||
chat = db.query(Chat).filter_by(share_id=id).first()
|
||||
|
||||
if chat:
|
||||
|
||||
@@ -27,6 +27,8 @@ class File(Base):
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
access_control = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
@@ -44,6 +46,8 @@ class FileModel(BaseModel):
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
created_at: Optional[int] # timestamp in epoch
|
||||
updated_at: Optional[int] # timestamp in epoch
|
||||
|
||||
@@ -90,6 +94,7 @@ class FileForm(BaseModel):
|
||||
path: str
|
||||
data: dict = {}
|
||||
meta: dict = {}
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
|
||||
class FilesTable:
|
||||
|
||||
@@ -146,6 +146,13 @@ class GroupTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
|
||||
group = self.get_group_by_id(id)
|
||||
if group:
|
||||
return group.user_ids
|
||||
else:
|
||||
return None
|
||||
|
||||
def update_group_by_id(
|
||||
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
||||
) -> Optional[GroupModel]:
|
||||
|
||||
279
backend/open_webui/models/messages.py
Normal file
279
backend/open_webui/models/messages.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
# Message DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class MessageReaction(Base):
|
||||
__tablename__ = "message_reaction"
|
||||
id = Column(Text, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
message_id = Column(Text)
|
||||
name = Column(Text)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class MessageReactionModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
message_id: str
|
||||
name: str
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "message"
|
||||
id = Column(Text, primary_key=True)
|
||||
|
||||
user_id = Column(Text)
|
||||
channel_id = Column(Text, nullable=True)
|
||||
|
||||
parent_id = Column(Text, nullable=True)
|
||||
|
||||
content = Column(Text)
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger) # time_ns
|
||||
updated_at = Column(BigInteger) # time_ns
|
||||
|
||||
|
||||
class MessageModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
channel_id: Optional[str] = None
|
||||
|
||||
parent_id: Optional[str] = None
|
||||
|
||||
content: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class MessageForm(BaseModel):
|
||||
content: str
|
||||
parent_id: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
|
||||
class Reactions(BaseModel):
|
||||
name: str
|
||||
user_ids: list[str]
|
||||
count: int
|
||||
|
||||
|
||||
class MessageResponse(MessageModel):
|
||||
latest_reply_at: Optional[int]
|
||||
reply_count: int
|
||||
reactions: list[Reactions]
|
||||
|
||||
|
||||
class MessageTable:
|
||||
def insert_new_message(
|
||||
self, form_data: MessageForm, channel_id: str, user_id: str
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
ts = int(time.time_ns())
|
||||
message = MessageModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"channel_id": channel_id,
|
||||
"parent_id": form_data.parent_id,
|
||||
"content": form_data.content,
|
||||
"data": form_data.data,
|
||||
"meta": form_data.meta,
|
||||
"created_at": ts,
|
||||
"updated_at": ts,
|
||||
}
|
||||
)
|
||||
|
||||
result = Message(**message.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return MessageModel.model_validate(result) if result else None
|
||||
|
||||
def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
|
||||
with get_db() as db:
|
||||
message = db.get(Message, id)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
reactions = self.get_reactions_by_message_id(id)
|
||||
replies = self.get_replies_by_message_id(id)
|
||||
|
||||
return MessageResponse(
|
||||
**{
|
||||
**MessageModel.model_validate(message).model_dump(),
|
||||
"latest_reply_at": replies[0].created_at if replies else None,
|
||||
"reply_count": len(replies),
|
||||
"reactions": reactions,
|
||||
}
|
||||
)
|
||||
|
||||
def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
|
||||
with get_db() as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(parent_id=id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [MessageModel.model_validate(message) for message in all_messages]
|
||||
|
||||
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
message.user_id
|
||||
for message in db.query(Message).filter_by(parent_id=id).all()
|
||||
]
|
||||
|
||||
def get_messages_by_channel_id(
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50
|
||||
) -> list[MessageModel]:
|
||||
with get_db() as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id, parent_id=None)
|
||||
.order_by(Message.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return [MessageModel.model_validate(message) for message in all_messages]
|
||||
|
||||
def get_messages_by_parent_id(
|
||||
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
|
||||
) -> list[MessageModel]:
|
||||
with get_db() as db:
|
||||
message = db.get(Message, parent_id)
|
||||
|
||||
if not message:
|
||||
return []
|
||||
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id, parent_id=parent_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# If length of all_messages is less than limit, then add the parent message
|
||||
if len(all_messages) < limit:
|
||||
all_messages.append(message)
|
||||
|
||||
return [MessageModel.model_validate(message) for message in all_messages]
|
||||
|
||||
def update_message_by_id(
|
||||
self, id: str, form_data: MessageForm
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
message = db.get(Message, id)
|
||||
message.content = form_data.content
|
||||
message.data = form_data.data
|
||||
message.meta = form_data.meta
|
||||
message.updated_at = int(time.time_ns())
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def add_reaction_to_message(
|
||||
self, id: str, user_id: str, name: str
|
||||
) -> Optional[MessageReactionModel]:
|
||||
with get_db() as db:
|
||||
reaction_id = str(uuid.uuid4())
|
||||
reaction = MessageReactionModel(
|
||||
id=reaction_id,
|
||||
user_id=user_id,
|
||||
message_id=id,
|
||||
name=name,
|
||||
created_at=int(time.time_ns()),
|
||||
)
|
||||
result = MessageReaction(**reaction.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return MessageReactionModel.model_validate(result) if result else None
|
||||
|
||||
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
|
||||
with get_db() as db:
|
||||
all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
|
||||
|
||||
reactions = {}
|
||||
for reaction in all_reactions:
|
||||
if reaction.name not in reactions:
|
||||
reactions[reaction.name] = {
|
||||
"name": reaction.name,
|
||||
"user_ids": [],
|
||||
"count": 0,
|
||||
}
|
||||
reactions[reaction.name]["user_ids"].append(reaction.user_id)
|
||||
reactions[reaction.name]["count"] += 1
|
||||
|
||||
return [Reactions(**reaction) for reaction in reactions.values()]
|
||||
|
||||
def remove_reaction_by_id_and_user_id_and_name(
|
||||
self, id: str, user_id: str, name: str
|
||||
) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(MessageReaction).filter_by(
|
||||
message_id=id, user_id=user_id, name=name
|
||||
).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_reactions_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(MessageReaction).filter_by(message_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_replies_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(Message).filter_by(parent_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_message_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(Message).filter_by(id=id).delete()
|
||||
|
||||
# Delete all reactions to this message
|
||||
db.query(MessageReaction).filter_by(message_id=id).delete()
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
Messages = MessageTable()
|
||||
@@ -70,6 +70,13 @@ class UserResponse(BaseModel):
|
||||
profile_image_url: str
|
||||
|
||||
|
||||
class UserNameResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
profile_image_url: str
|
||||
|
||||
|
||||
class UserRoleUpdateForm(BaseModel):
|
||||
id: str
|
||||
role: str
|
||||
@@ -147,13 +154,25 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]:
|
||||
def get_users(
|
||||
self, skip: Optional[int] = None, limit: Optional[int] = None
|
||||
) -> list[UserModel]:
|
||||
with get_db() as db:
|
||||
users = (
|
||||
db.query(User)
|
||||
# .offset(skip).limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
query = db.query(User).order_by(User.created_at.desc())
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
users = query.all()
|
||||
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
|
||||
with get_db() as db:
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_num_users(self) -> Optional[int]:
|
||||
@@ -168,6 +187,22 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
|
||||
if user.settings is None:
|
||||
return None
|
||||
else:
|
||||
return (
|
||||
user.settings.get("ui", {})
|
||||
.get("notifications", {})
|
||||
.get("webhook_url", None)
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -14,7 +14,7 @@ from langchain_core.documents import Document
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@@ -70,7 +70,9 @@ def query_doc(
|
||||
limit=k,
|
||||
)
|
||||
|
||||
log.info(f"query_doc:result {result.ids} {result.metadatas}")
|
||||
if result:
|
||||
log.info(f"query_doc:result {result.ids} {result.metadatas}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
@@ -373,6 +375,9 @@ def get_model_path(model: str, update_model: bool = False):
|
||||
|
||||
local_files_only = not update_model
|
||||
|
||||
if OFFLINE_MODE:
|
||||
local_files_only = True
|
||||
|
||||
snapshot_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy import (
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
MetaData,
|
||||
select,
|
||||
text,
|
||||
Text,
|
||||
@@ -19,9 +20,9 @@ from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import PGVECTOR_DB_URL
|
||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
|
||||
VECTOR_LENGTH = 1536
|
||||
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -56,6 +57,9 @@ class PgvectorClient:
|
||||
# Ensure the pgvector extension is available
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||
|
||||
# Check vector length consistency
|
||||
self.check_vector_length()
|
||||
|
||||
# Create the tables if they do not exist
|
||||
# Base.metadata.create_all requires a bind (engine or connection)
|
||||
# Get the connection from the session
|
||||
@@ -82,6 +86,38 @@ class PgvectorClient:
|
||||
print(f"Error during initialization: {e}")
|
||||
raise
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
"""
|
||||
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
|
||||
Raises an exception if there is a mismatch.
|
||||
"""
|
||||
metadata = MetaData()
|
||||
metadata.reflect(bind=self.session.bind, only=["document_chunk"])
|
||||
|
||||
if "document_chunk" in metadata.tables:
|
||||
document_chunk_table = metadata.tables["document_chunk"]
|
||||
if "vector" in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns["vector"]
|
||||
vector_type = vector_column.type
|
||||
if isinstance(vector_type, Vector):
|
||||
db_vector_length = vector_type.dim
|
||||
if db_vector_length != VECTOR_LENGTH:
|
||||
raise Exception(
|
||||
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
||||
"Cannot change vector size after initialization without migrating the data."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"The 'vector' column exists but is not of type 'Vector'."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||
)
|
||||
else:
|
||||
# Table does not exist yet; no action needed
|
||||
pass
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
# Adjust vector to have length VECTOR_LENGTH
|
||||
current_length = len(vector)
|
||||
|
||||
@@ -683,7 +683,7 @@
|
||||
"age": "October 29, 2022",
|
||||
"extra_snippets": [
|
||||
"You can pass many options to the configure script; run ./configure --help to find out more. On macOS case-insensitive file systems and on Cygwin, the executable is called python.exe; elsewhere it's just python.",
|
||||
"Building a complete Python installation requires the use of various additional third-party libraries, depending on your build platform and configure options. Not all standard library modules are buildable or useable on all platforms. Refer to the Install dependencies section of the Developer Guide for current detailed information on dependencies for various Linux distributions and macOS.",
|
||||
"Building a complete Python installation requires the use of various additional third-party libraries, depending on your build platform and configure options. Not all standard library modules are buildable or usable on all platforms. Refer to the Install dependencies section of the Developer Guide for current detailed information on dependencies for various Linux distributions and macOS.",
|
||||
"To get an optimized build of Python, configure --enable-optimizations before you run make. This sets the default make targets up to enable Profile Guided Optimization (PGO) and may be used to auto-enable Link Time Optimization (LTO) on some platforms. For more details, see the sections below.",
|
||||
"Copyright © 2001-2024 Python Software Foundation. All rights reserved."
|
||||
]
|
||||
|
||||
@@ -82,15 +82,15 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
|
||||
|
||||
def get_web_loader(
|
||||
url: Union[str, Sequence[str]],
|
||||
urls: Union[str, Sequence[str]],
|
||||
verify_ssl: bool = True,
|
||||
requests_per_second: int = 2,
|
||||
):
|
||||
# Check if the URL is valid
|
||||
if not validate_url(url):
|
||||
if not validate_url(urls):
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
return SafeWebBaseLoader(
|
||||
url,
|
||||
urls,
|
||||
verify_ssl=verify_ssl,
|
||||
requests_per_second=requests_per_second,
|
||||
continue_on_failure=True,
|
||||
|
||||
@@ -218,7 +218,7 @@ async def update_audio_config(
|
||||
}
|
||||
|
||||
|
||||
def load_speech_pipeline():
|
||||
def load_speech_pipeline(request):
|
||||
from transformers import pipeline
|
||||
from datasets import load_dataset
|
||||
|
||||
@@ -236,7 +236,11 @@ def load_speech_pipeline():
|
||||
@router.post("/speech")
|
||||
async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
name = hashlib.sha256(body).hexdigest()
|
||||
name = hashlib.sha256(
|
||||
body
|
||||
+ str(request.app.state.config.TTS_ENGINE).encode("utf-8")
|
||||
+ str(request.app.state.config.TTS_MODEL).encode("utf-8")
|
||||
).hexdigest()
|
||||
|
||||
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
||||
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
||||
@@ -256,10 +260,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
payload["model"] = request.app.state.config.TTS_MODEL
|
||||
|
||||
try:
|
||||
# print(payload)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
||||
data=payload,
|
||||
json=payload,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
|
||||
@@ -281,7 +286,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
await f.write(await r.read())
|
||||
|
||||
async with aiofiles.open(file_body_path, "w") as f:
|
||||
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
|
||||
await f.write(json.dumps(payload))
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
@@ -292,6 +297,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
try:
|
||||
if r.status != 200:
|
||||
res = await r.json()
|
||||
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
@@ -305,7 +311,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
voice_id = payload.get("voice", "")
|
||||
|
||||
if voice_id not in get_available_voices():
|
||||
if voice_id not in get_available_voices(request):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid voice id",
|
||||
@@ -332,7 +338,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
await f.write(await r.read())
|
||||
|
||||
async with aiofiles.open(file_body_path, "w") as f:
|
||||
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
|
||||
await f.write(json.dumps(payload))
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
@@ -384,6 +390,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(await r.read())
|
||||
|
||||
async with aiofiles.open(file_body_path, "w") as f:
|
||||
await f.write(json.dumps(payload))
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
except Exception as e:
|
||||
@@ -414,7 +423,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
import torch
|
||||
import soundfile as sf
|
||||
|
||||
load_speech_pipeline()
|
||||
load_speech_pipeline(request)
|
||||
|
||||
embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
|
||||
|
||||
@@ -436,8 +445,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(json.loads(body.decode("utf-8")), f)
|
||||
|
||||
async with aiofiles.open(file_body_path, "w") as f:
|
||||
await f.write(json.dumps(payload))
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
@@ -547,7 +547,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
try:
|
||||
print(form_data)
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
form_data.email.lower(),
|
||||
@@ -614,8 +613,12 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
||||
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
||||
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
||||
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
||||
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
||||
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
||||
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
@@ -625,8 +628,12 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
|
||||
|
||||
class AdminConfig(BaseModel):
|
||||
SHOW_ADMIN_DETAILS: bool
|
||||
WEBUI_URL: str
|
||||
ENABLE_SIGNUP: bool
|
||||
ENABLE_API_KEY: bool
|
||||
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
|
||||
API_KEY_ALLOWED_ENDPOINTS: str
|
||||
ENABLE_CHANNELS: bool
|
||||
DEFAULT_USER_ROLE: str
|
||||
JWT_EXPIRES_IN: str
|
||||
ENABLE_COMMUNITY_SHARING: bool
|
||||
@@ -638,8 +645,18 @@ async def update_admin_config(
|
||||
request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
|
||||
request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
|
||||
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
|
||||
|
||||
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
|
||||
request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
|
||||
form_data.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
|
||||
)
|
||||
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS = (
|
||||
form_data.API_KEY_ALLOWED_ENDPOINTS
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
|
||||
|
||||
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
|
||||
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
|
||||
@@ -657,8 +674,12 @@ async def update_admin_config(
|
||||
|
||||
return {
|
||||
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
||||
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
||||
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
||||
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
||||
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
||||
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
||||
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
|
||||
710
backend/open_webui/routers/channels.py
Normal file
710
backend/open_webui/routers/channels.py
Normal file
@@ -0,0 +1,710 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.socket.main import sio, get_user_ids_from_room
|
||||
from open_webui.models.users import Users, UserNameResponse
|
||||
|
||||
from open_webui.models.channels import Channels, ChannelModel, ChannelForm
|
||||
from open_webui.models.messages import (
|
||||
Messages,
|
||||
MessageModel,
|
||||
MessageResponse,
|
||||
MessageForm,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, get_users_with_access
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetChatList
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ChannelModel])
|
||||
async def get_channels(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
return Channels.get_channels()
|
||||
else:
|
||||
return Channels.get_channels_by_user_id(user.id)
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewChannel
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[ChannelModel])
|
||||
async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
|
||||
try:
|
||||
channel = Channels.insert_new_channel(None, form_data, user.id)
|
||||
return ChannelModel(**channel.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChannelById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[ChannelModel])
|
||||
async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
return ChannelModel(**channel.model_dump())
|
||||
|
||||
|
||||
############################
|
||||
# UpdateChannelById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/update", response_model=Optional[ChannelModel])
|
||||
async def update_channel_by_id(
|
||||
id: str, form_data: ChannelForm, user=Depends(get_admin_user)
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
try:
|
||||
channel = Channels.update_channel_by_id(id, form_data)
|
||||
return ChannelModel(**channel.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteChannelById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}/delete", response_model=bool)
|
||||
async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
try:
|
||||
Channels.delete_channel_by_id(id)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChannelMessages
|
||||
############################
|
||||
|
||||
|
||||
class MessageUserResponse(MessageResponse):
|
||||
user: UserNameResponse
|
||||
|
||||
|
||||
@router.get("/{id}/messages", response_model=list[MessageUserResponse])
|
||||
async def get_channel_messages(
|
||||
id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message_list = Messages.get_messages_by_channel_id(id, skip, limit)
|
||||
users = {}
|
||||
|
||||
messages = []
|
||||
for message in message_list:
|
||||
if message.user_id not in users:
|
||||
user = Users.get_user_by_id(message.user_id)
|
||||
users[message.user_id] = user
|
||||
|
||||
replies = Messages.get_replies_by_message_id(message.id)
|
||||
latest_reply_at = replies[0].created_at if replies else None
|
||||
|
||||
messages.append(
|
||||
MessageUserResponse(
|
||||
**{
|
||||
**message.model_dump(),
|
||||
"reply_count": len(replies),
|
||||
"latest_reply_at": latest_reply_at,
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id),
|
||||
"user": UserNameResponse(**users[message.user_id].model_dump()),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
############################
|
||||
# PostNewMessage
|
||||
############################
|
||||
|
||||
|
||||
async def send_notification(webui_url, channel, message, active_user_ids):
|
||||
users = get_users_with_access("read", channel.access_control)
|
||||
|
||||
for user in users:
|
||||
if user.id in active_user_ids:
|
||||
continue
|
||||
else:
|
||||
if user.settings:
|
||||
webhook_url = user.settings.ui.get("notifications", {}).get(
|
||||
"webhook_url", None
|
||||
)
|
||||
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
webhook_url,
|
||||
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
|
||||
{
|
||||
"action": "channel",
|
||||
"message": message.content,
|
||||
"title": channel.name,
|
||||
"url": f"{webui_url}/channels/{channel.id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
|
||||
async def post_new_message(
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: MessageForm,
|
||||
background_tasks: BackgroundTasks,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
message = Messages.insert_new_message(form_data, channel.id, user.id)
|
||||
|
||||
if message:
|
||||
event_data = {
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
"data": {
|
||||
"type": "message",
|
||||
"data": MessageUserResponse(
|
||||
**{
|
||||
**message.model_dump(),
|
||||
"reply_count": 0,
|
||||
"latest_reply_at": None,
|
||||
"reactions": Messages.get_reactions_by_message_id(
|
||||
message.id
|
||||
),
|
||||
"user": UserNameResponse(**user.model_dump()),
|
||||
}
|
||||
).model_dump(),
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
}
|
||||
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
event_data,
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
if message.parent_id:
|
||||
# If this message is a reply, emit to the parent message as well
|
||||
parent_message = Messages.get_message_by_id(message.parent_id)
|
||||
|
||||
if parent_message:
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": channel.id,
|
||||
"message_id": parent_message.id,
|
||||
"data": {
|
||||
"type": "message:reply",
|
||||
"data": MessageUserResponse(
|
||||
**{
|
||||
**parent_message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**Users.get_user_by_id(
|
||||
parent_message.user_id
|
||||
).model_dump()
|
||||
),
|
||||
}
|
||||
).model_dump(),
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
},
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
|
||||
|
||||
background_tasks.add_task(
|
||||
send_notification,
|
||||
request.app.state.config.WEBUI_URL,
|
||||
channel,
|
||||
message,
|
||||
active_user_ids,
|
||||
)
|
||||
|
||||
return MessageModel(**message.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChannelMessage
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/messages/{message_id}", response_model=Optional[MessageUserResponse])
|
||||
async def get_channel_message(
|
||||
id: str, message_id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if message.channel_id != id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
return MessageUserResponse(
|
||||
**{
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**Users.get_user_by_id(message.user_id).model_dump()
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChannelThreadMessages
|
||||
############################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{id}/messages/{message_id}/thread", response_model=list[MessageUserResponse]
|
||||
)
|
||||
async def get_channel_thread_messages(
|
||||
id: str,
|
||||
message_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit)
|
||||
users = {}
|
||||
|
||||
messages = []
|
||||
for message in message_list:
|
||||
if message.user_id not in users:
|
||||
user = Users.get_user_by_id(message.user_id)
|
||||
users[message.user_id] = user
|
||||
|
||||
messages.append(
|
||||
MessageUserResponse(
|
||||
**{
|
||||
**message.model_dump(),
|
||||
"reply_count": 0,
|
||||
"latest_reply_at": None,
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id),
|
||||
"user": UserNameResponse(**users[message.user_id].model_dump()),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
############################
|
||||
# UpdateMessageById
|
||||
############################
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{id}/messages/{message_id}/update", response_model=Optional[MessageModel]
|
||||
)
|
||||
async def update_message_by_id(
|
||||
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if message.channel_id != id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
message = Messages.update_message_by_id(message_id, form_data)
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
|
||||
if message:
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
"data": {
|
||||
"type": "message:update",
|
||||
"data": MessageUserResponse(
|
||||
**{
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**user.model_dump()
|
||||
).model_dump(),
|
||||
}
|
||||
).model_dump(),
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
},
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
return MessageModel(**message.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# AddReactionToMessage
|
||||
############################
|
||||
|
||||
|
||||
class ReactionForm(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.post("/{id}/messages/{message_id}/reactions/add", response_model=bool)
|
||||
async def add_reaction_to_message(
|
||||
id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user)
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if message.channel_id != id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
Messages.add_reaction_to_message(message_id, user.id, form_data.name)
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
"data": {
|
||||
"type": "message:reaction:add",
|
||||
"data": {
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**Users.get_user_by_id(message.user_id).model_dump()
|
||||
).model_dump(),
|
||||
"name": form_data.name,
|
||||
},
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
},
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# RemoveReactionById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/messages/{message_id}/reactions/remove", response_model=bool)
|
||||
async def remove_reaction_by_id_and_user_id_and_name(
|
||||
id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user)
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if message.channel_id != id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
Messages.remove_reaction_by_id_and_user_id_and_name(
|
||||
message_id, user.id, form_data.name
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
"data": {
|
||||
"type": "message:reaction:remove",
|
||||
"data": {
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**Users.get_user_by_id(message.user_id).model_dump()
|
||||
).model_dump(),
|
||||
"name": form_data.name,
|
||||
},
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
},
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteMessageById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}/messages/{message_id}/delete", response_model=bool)
|
||||
async def delete_message_by_id(
|
||||
id: str, message_id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if message.channel_id != id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
Messages.delete_message_by_id(message_id)
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
"data": {
|
||||
"type": "message:delete",
|
||||
"data": {
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
},
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
},
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
if message.parent_id:
|
||||
# If this message is a reply, emit to the parent message as well
|
||||
parent_message = Messages.get_message_by_id(message.parent_id)
|
||||
|
||||
if parent_message:
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": channel.id,
|
||||
"message_id": parent_message.id,
|
||||
"data": {
|
||||
"type": "message:reply",
|
||||
"data": MessageUserResponse(
|
||||
**{
|
||||
**parent_message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**Users.get_user_by_id(
|
||||
parent_message.user_id
|
||||
).model_dump()
|
||||
),
|
||||
}
|
||||
).model_dump(),
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
},
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
@@ -463,6 +463,30 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# CloneSharedChatById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
|
||||
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_share_id(id)
|
||||
if chat:
|
||||
updated_chat = {
|
||||
**chat.chat,
|
||||
"originalChatId": chat.id,
|
||||
"branchPointMessageId": chat.chat["history"]["currentId"],
|
||||
"title": f"Clone of {chat.title}",
|
||||
}
|
||||
|
||||
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ArchiveChat
|
||||
############################
|
||||
|
||||
@@ -226,9 +226,16 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
# Handle Unicode filenames
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename) # RFC5987 encoding
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
}
|
||||
|
||||
headers = {}
|
||||
if file.meta.get("content_type") not in [
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
]:
|
||||
headers = {
|
||||
**headers,
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
}
|
||||
|
||||
return FileResponse(file_path, headers=headers)
|
||||
|
||||
@@ -341,7 +348,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
result = Files.delete_file_by_id(id)
|
||||
if result:
|
||||
try:
|
||||
Storage.delete_file(file.filename)
|
||||
Storage.delete_file(file.path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting files")
|
||||
|
||||
@@ -56,6 +56,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
},
|
||||
"comfyui": {
|
||||
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
|
||||
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
@@ -77,6 +78,7 @@ class Automatic1111ConfigForm(BaseModel):
|
||||
|
||||
class ComfyUIConfigForm(BaseModel):
|
||||
COMFYUI_BASE_URL: str
|
||||
COMFYUI_API_KEY: str
|
||||
COMFYUI_WORKFLOW: str
|
||||
COMFYUI_WORKFLOW_NODES: list[dict]
|
||||
|
||||
@@ -148,6 +150,7 @@ async def update_config(
|
||||
},
|
||||
"comfyui": {
|
||||
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
|
||||
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
@@ -197,7 +200,7 @@ def set_image_model(request: Request, model: str):
|
||||
log.info(f"Setting image model to {model}")
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL = model
|
||||
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
|
||||
api_auth = get_automatic1111_api_auth()
|
||||
api_auth = get_automatic1111_api_auth(request)
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
headers={"authorization": api_auth},
|
||||
@@ -233,7 +236,7 @@ def get_image_model(request):
|
||||
try:
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
headers={"authorization": get_automatic1111_api_auth()},
|
||||
headers={"authorization": get_automatic1111_api_auth(request)},
|
||||
)
|
||||
options = r.json()
|
||||
return options["sd_model_checkpoint"]
|
||||
@@ -298,8 +301,12 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
# TODO - get models from comfyui
|
||||
headers = {
|
||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
||||
headers=headers,
|
||||
)
|
||||
info = r.json()
|
||||
|
||||
@@ -347,7 +354,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
):
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
|
||||
headers={"authorization": get_automatic1111_api_auth()},
|
||||
headers={"authorization": get_automatic1111_api_auth(request)},
|
||||
)
|
||||
models = r.json()
|
||||
return list(
|
||||
@@ -521,6 +528,7 @@ async def image_generations(
|
||||
form_data,
|
||||
user.id,
|
||||
request.app.state.config.COMFYUI_BASE_URL,
|
||||
request.app.state.config.COMFYUI_API_KEY,
|
||||
)
|
||||
log.debug(f"res: {res}")
|
||||
|
||||
@@ -570,7 +578,7 @@ async def image_generations(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
|
||||
json=data,
|
||||
headers={"authorization": get_automatic1111_api_auth()},
|
||||
headers={"authorization": get_automatic1111_api_auth(request)},
|
||||
)
|
||||
|
||||
res = r.json()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
import logging
|
||||
@@ -12,11 +11,16 @@ from open_webui.models.knowledge import (
|
||||
)
|
||||
from open_webui.models.files import Files, FileModel
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.routers.retrieval import process_file, ProcessFileForm
|
||||
from open_webui.routers.retrieval import (
|
||||
process_file,
|
||||
ProcessFileForm,
|
||||
process_files_batch,
|
||||
BatchProcessFilesForm,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.auth import get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
|
||||
|
||||
@@ -415,13 +419,6 @@ def remove_file_from_knowledge_by_id(
|
||||
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
|
||||
)
|
||||
|
||||
result = VECTOR_DB_CLIENT.query(
|
||||
collection_name=knowledge.id,
|
||||
filter={"file_id": form_data.file_id},
|
||||
)
|
||||
|
||||
Files.delete_file_by_id(form_data.file_id)
|
||||
|
||||
if knowledge:
|
||||
data = knowledge.data or {}
|
||||
file_ids = data.get("file_ids", [])
|
||||
@@ -514,3 +511,86 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
|
||||
|
||||
return knowledge
|
||||
|
||||
|
||||
############################
|
||||
# AddFilesToKnowledge
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
|
||||
def add_files_to_knowledge_batch(
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: list[KnowledgeFileIdForm],
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
"""
|
||||
Add multiple files to a knowledge base
|
||||
"""
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
# Get files content
|
||||
print(f"files/batch/add - {len(form_data)} files")
|
||||
files: List[FileModel] = []
|
||||
for form in form_data:
|
||||
file = Files.get_file_by_id(form.file_id)
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File {form.file_id} not found",
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
# Process files
|
||||
try:
|
||||
result = process_files_batch(
|
||||
request=request,
|
||||
form_data=BatchProcessFilesForm(files=files, collection_name=id),
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
# Add successful files to knowledge base
|
||||
data = knowledge.data or {}
|
||||
existing_file_ids = data.get("file_ids", [])
|
||||
|
||||
# Only add files that were successfully processed
|
||||
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
|
||||
for file_id in successful_file_ids:
|
||||
if file_id not in existing_file_ids:
|
||||
existing_file_ids.append(file_id)
|
||||
|
||||
data["file_ids"] = existing_file_ids
|
||||
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
|
||||
|
||||
# If there were any errors, include them in the response
|
||||
if result.errors:
|
||||
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Files.get_files_by_ids(existing_file_ids),
|
||||
warnings={
|
||||
"message": "Some files failed to process",
|
||||
"errors": error_details,
|
||||
},
|
||||
)
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
|
||||
)
|
||||
|
||||
@@ -82,6 +82,16 @@ async def send_get_request(url, key=None):
|
||||
return None
|
||||
|
||||
|
||||
async def cleanup_response(
|
||||
response: Optional[aiohttp.ClientResponse],
|
||||
session: Optional[aiohttp.ClientSession],
|
||||
):
|
||||
if response:
|
||||
response.close()
|
||||
if session:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def send_post_request(
|
||||
url: str,
|
||||
payload: Union[str, bytes],
|
||||
@@ -89,14 +99,6 @@ async def send_post_request(
|
||||
key: Optional[str] = None,
|
||||
content_type: Optional[str] = None,
|
||||
):
|
||||
async def cleanup_response(
|
||||
response: Optional[aiohttp.ClientResponse],
|
||||
session: Optional[aiohttp.ClientSession],
|
||||
):
|
||||
if response:
|
||||
response.close()
|
||||
if session:
|
||||
await session.close()
|
||||
|
||||
r = None
|
||||
try:
|
||||
@@ -917,7 +919,7 @@ class ChatMessage(BaseModel):
|
||||
class GenerateChatCompletionForm(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
format: Optional[str] = None
|
||||
format: Optional[dict] = None
|
||||
options: Optional[dict] = None
|
||||
template: Optional[str] = None
|
||||
stream: Optional[bool] = True
|
||||
|
||||
@@ -533,6 +533,9 @@ async def generate_chat_completion(
|
||||
user=Depends(get_verified_user),
|
||||
bypass_filter: Optional[bool] = False,
|
||||
):
|
||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
|
||||
idx = 0
|
||||
payload = {**form_data}
|
||||
if "metadata" in payload:
|
||||
@@ -545,6 +548,7 @@ async def generate_chat_completion(
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
model_id = model_info.base_model_id
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
@@ -604,14 +608,13 @@ async def generate_chat_completion(
|
||||
if is_o1:
|
||||
payload = openai_o1_handler(payload)
|
||||
elif "api.openai.com" not in url:
|
||||
# Remove "max_tokens" from the payload for backward compatibility
|
||||
if "max_tokens" in payload:
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
if "max_completion_tokens" in payload:
|
||||
payload["max_tokens"] = payload["max_completion_tokens"]
|
||||
del payload["max_completion_tokens"]
|
||||
|
||||
# TODO: check if below is needed
|
||||
# if "max_tokens" in payload and "max_completion_tokens" in payload:
|
||||
# del payload["max_tokens"]
|
||||
if "max_tokens" in payload and "max_completion_tokens" in payload:
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
payload = json.dumps(payload)
|
||||
|
||||
@@ -124,18 +124,14 @@ def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers={"Authorization": f"Bearer {key}"},
|
||||
json={
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
},
|
||||
"body": data,
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
payload = data
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
@@ -7,7 +7,7 @@ import shutil
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, Sequence, Union
|
||||
from typing import Iterator, List, Optional, Sequence, Union
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
@@ -28,7 +28,7 @@ import tiktoken
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from open_webui.models.files import Files
|
||||
from open_webui.models.files import FileModel, Files
|
||||
from open_webui.models.knowledge import Knowledges
|
||||
from open_webui.storage.provider import Storage
|
||||
|
||||
@@ -347,6 +347,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"content_extraction": {
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
@@ -369,6 +370,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"search": {
|
||||
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
|
||||
"searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
|
||||
@@ -445,6 +447,7 @@ class WebConfig(BaseModel):
|
||||
|
||||
class ConfigUpdateForm(BaseModel):
|
||||
pdf_extract_images: Optional[bool] = None
|
||||
enable_google_drive_integration: Optional[bool] = None
|
||||
file: Optional[FileConfig] = None
|
||||
content_extraction: Optional[ContentExtractionConfig] = None
|
||||
chunk: Optional[ChunkParamUpdateForm] = None
|
||||
@@ -462,6 +465,12 @@ async def update_rag_config(
|
||||
else request.app.state.config.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
|
||||
form_data.enable_google_drive_integration
|
||||
if form_data.enable_google_drive_integration is not None
|
||||
else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
)
|
||||
|
||||
if form_data.file is not None:
|
||||
request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
|
||||
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
|
||||
@@ -1247,21 +1256,22 @@ def process_web_search(
|
||||
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
||||
)
|
||||
|
||||
log.debug(f"web_results: {web_results}")
|
||||
|
||||
try:
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name == "":
|
||||
if collection_name == "" or collection_name is None:
|
||||
collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
|
||||
:63
|
||||
]
|
||||
|
||||
urls = [result.link for result in web_results]
|
||||
loader = get_web_loader(
|
||||
urls=urls,
|
||||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
)
|
||||
docs = loader.aload()
|
||||
|
||||
docs = loader.load()
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
|
||||
return {
|
||||
@@ -1428,3 +1438,94 @@ if ENV == "dev":
|
||||
@router.get("/ef/{text}")
|
||||
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
|
||||
return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
|
||||
|
||||
|
||||
class BatchProcessFilesForm(BaseModel):
|
||||
files: List[FileModel]
|
||||
collection_name: str
|
||||
|
||||
|
||||
class BatchProcessFilesResult(BaseModel):
|
||||
file_id: str
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class BatchProcessFilesResponse(BaseModel):
|
||||
results: List[BatchProcessFilesResult]
|
||||
errors: List[BatchProcessFilesResult]
|
||||
|
||||
|
||||
@router.post("/process/files/batch")
|
||||
def process_files_batch(
|
||||
request: Request,
|
||||
form_data: BatchProcessFilesForm,
|
||||
user=Depends(get_verified_user),
|
||||
) -> BatchProcessFilesResponse:
|
||||
"""
|
||||
Process a batch of files and save them to the vector database.
|
||||
"""
|
||||
results: List[BatchProcessFilesResult] = []
|
||||
errors: List[BatchProcessFilesResult] = []
|
||||
collection_name = form_data.collection_name
|
||||
|
||||
# Prepare all documents first
|
||||
all_docs: List[Document] = []
|
||||
for file in form_data.files:
|
||||
try:
|
||||
text_content = file.data.get("content", "")
|
||||
|
||||
docs: List[Document] = [
|
||||
Document(
|
||||
page_content=text_content.replace("<br/>", "\n"),
|
||||
metadata={
|
||||
**file.meta,
|
||||
"name": file.filename,
|
||||
"created_by": file.user_id,
|
||||
"file_id": file.id,
|
||||
"source": file.filename,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
Files.update_file_data_by_id(file.id, {"content": text_content})
|
||||
|
||||
all_docs.extend(docs)
|
||||
results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
|
||||
errors.append(
|
||||
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
|
||||
)
|
||||
|
||||
# Save all documents in one batch
|
||||
if all_docs:
|
||||
try:
|
||||
save_docs_to_vector_db(
|
||||
request=request,
|
||||
docs=all_docs,
|
||||
collection_name=collection_name,
|
||||
add=True,
|
||||
)
|
||||
|
||||
# Update all files with collection name
|
||||
for result in results:
|
||||
Files.update_file_metadata_by_id(
|
||||
result.file_id, {"collection_name": collection_name}
|
||||
)
|
||||
result.status = "completed"
|
||||
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"process_files_batch: Error saving documents to vector DB: {str(e)}"
|
||||
)
|
||||
for result in results:
|
||||
result.status = "failed"
|
||||
errors.append(
|
||||
BatchProcessFilesResult(file_id=result.file_id, error=str(e))
|
||||
)
|
||||
|
||||
return BatchProcessFilesResponse(results=results, errors=errors)
|
||||
|
||||
@@ -186,9 +186,10 @@ async def generate_title(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error("Exception occurred", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
content={"detail": "An internal error has occurred."},
|
||||
)
|
||||
|
||||
|
||||
@@ -248,9 +249,10 @@ async def generate_chat_tags(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error(f"Error generating chat completion: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
)
|
||||
|
||||
|
||||
@@ -393,9 +395,10 @@ async def generate_autocompletion(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error(f"Error generating chat completion: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
)
|
||||
|
||||
|
||||
@@ -496,8 +499,8 @@ async def generate_moa_response(
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": form_data.get("stream", False),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"task": str(TASKS.MOA_RESPONSE_GENERATION),
|
||||
"task_body": form_data,
|
||||
},
|
||||
|
||||
@@ -10,6 +10,9 @@ from open_webui.models.users import (
|
||||
UserSettings,
|
||||
UserUpdateForm,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.socket.main import get_active_status_by_user_id
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
@@ -27,7 +30,11 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[UserModel])
|
||||
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
|
||||
async def get_users(
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
return Users.get_users(skip, limit)
|
||||
|
||||
|
||||
@@ -192,6 +199,7 @@ async def update_user_info_by_session_user(
|
||||
class UserResponse(BaseModel):
|
||||
name: str
|
||||
profile_image_url: str
|
||||
active: Optional[bool] = None
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
@@ -212,7 +220,13 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user_id)
|
||||
|
||||
if user:
|
||||
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
|
||||
return UserResponse(
|
||||
**{
|
||||
"name": user.name,
|
||||
"profile_image_url": user.profile_image_url,
|
||||
"active": get_active_status_by_user_id(user_id),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
||||
@@ -4,14 +4,17 @@ import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.users import Users, UserNameResponse
|
||||
from open_webui.models.channels import Channels
|
||||
from open_webui.models.chats import Chats
|
||||
|
||||
from open_webui.env import (
|
||||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
WEBSOCKET_MANAGER,
|
||||
WEBSOCKET_REDIS_URL,
|
||||
)
|
||||
from open_webui.utils.auth import decode_token
|
||||
from open_webui.socket.utils import RedisDict
|
||||
from open_webui.socket.utils import RedisDict, RedisLock
|
||||
|
||||
from open_webui.env import (
|
||||
GLOBAL_LOG_LEVEL,
|
||||
@@ -29,9 +32,7 @@ if WEBSOCKET_MANAGER == "redis":
|
||||
sio = socketio.AsyncServer(
|
||||
cors_allowed_origins=[],
|
||||
async_mode="asgi",
|
||||
transports=(
|
||||
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
|
||||
),
|
||||
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
|
||||
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
||||
always_connect=True,
|
||||
client_manager=mgr,
|
||||
@@ -40,54 +41,77 @@ else:
|
||||
sio = socketio.AsyncServer(
|
||||
cors_allowed_origins=[],
|
||||
async_mode="asgi",
|
||||
transports=(
|
||||
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
|
||||
),
|
||||
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
|
||||
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
|
||||
always_connect=True,
|
||||
)
|
||||
|
||||
|
||||
# Timeout duration in seconds
|
||||
TIMEOUT_DURATION = 3
|
||||
|
||||
# Dictionary to maintain the user pool
|
||||
|
||||
if WEBSOCKET_MANAGER == "redis":
|
||||
log.debug("Using Redis to manage websockets.")
|
||||
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
|
||||
clean_up_lock = RedisLock(
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
lock_name="usage_cleanup_lock",
|
||||
timeout_secs=TIMEOUT_DURATION * 2,
|
||||
)
|
||||
aquire_func = clean_up_lock.aquire_lock
|
||||
renew_func = clean_up_lock.renew_lock
|
||||
release_func = clean_up_lock.release_lock
|
||||
else:
|
||||
SESSION_POOL = {}
|
||||
USER_POOL = {}
|
||||
USAGE_POOL = {}
|
||||
|
||||
|
||||
# Timeout duration in seconds
|
||||
TIMEOUT_DURATION = 3
|
||||
aquire_func = release_func = renew_func = lambda: True
|
||||
|
||||
|
||||
async def periodic_usage_pool_cleanup():
|
||||
while True:
|
||||
now = int(time.time())
|
||||
for model_id, connections in list(USAGE_POOL.items()):
|
||||
# Creating a list of sids to remove if they have timed out
|
||||
expired_sids = [
|
||||
sid
|
||||
for sid, details in connections.items()
|
||||
if now - details["updated_at"] > TIMEOUT_DURATION
|
||||
]
|
||||
if not aquire_func():
|
||||
log.debug("Usage pool cleanup lock already exists. Not running it.")
|
||||
return
|
||||
log.debug("Running periodic_usage_pool_cleanup")
|
||||
try:
|
||||
while True:
|
||||
if not renew_func():
|
||||
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
|
||||
raise Exception("Unable to renew usage pool cleanup lock.")
|
||||
|
||||
for sid in expired_sids:
|
||||
del connections[sid]
|
||||
now = int(time.time())
|
||||
send_usage = False
|
||||
for model_id, connections in list(USAGE_POOL.items()):
|
||||
# Creating a list of sids to remove if they have timed out
|
||||
expired_sids = [
|
||||
sid
|
||||
for sid, details in connections.items()
|
||||
if now - details["updated_at"] > TIMEOUT_DURATION
|
||||
]
|
||||
|
||||
if not connections:
|
||||
log.debug(f"Cleaning up model {model_id} from usage pool")
|
||||
del USAGE_POOL[model_id]
|
||||
else:
|
||||
USAGE_POOL[model_id] = connections
|
||||
for sid in expired_sids:
|
||||
del connections[sid]
|
||||
|
||||
# Emit updated usage information after cleaning
|
||||
await sio.emit("usage", {"models": get_models_in_use()})
|
||||
if not connections:
|
||||
log.debug(f"Cleaning up model {model_id} from usage pool")
|
||||
del USAGE_POOL[model_id]
|
||||
else:
|
||||
USAGE_POOL[model_id] = connections
|
||||
|
||||
await asyncio.sleep(TIMEOUT_DURATION)
|
||||
send_usage = True
|
||||
|
||||
if send_usage:
|
||||
# Emit updated usage information after cleaning
|
||||
await sio.emit("usage", {"models": get_models_in_use()})
|
||||
|
||||
await asyncio.sleep(TIMEOUT_DURATION)
|
||||
finally:
|
||||
release_func()
|
||||
|
||||
|
||||
app = socketio.ASGIApp(
|
||||
@@ -128,20 +152,19 @@ async def connect(sid, environ, auth):
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
|
||||
if user:
|
||||
SESSION_POOL[sid] = user.id
|
||||
SESSION_POOL[sid] = user.model_dump()
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id].append(sid)
|
||||
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
|
||||
else:
|
||||
USER_POOL[user.id] = [sid]
|
||||
|
||||
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
||||
await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
||||
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
|
||||
await sio.emit("usage", {"models": get_models_in_use()})
|
||||
|
||||
|
||||
@sio.on("user-join")
|
||||
async def user_join(sid, data):
|
||||
# print("user-join", sid, data)
|
||||
|
||||
auth = data["auth"] if "auth" in data else None
|
||||
if not auth or "token" not in auth:
|
||||
@@ -155,39 +178,91 @@ async def user_join(sid, data):
|
||||
if not user:
|
||||
return
|
||||
|
||||
SESSION_POOL[sid] = user.id
|
||||
SESSION_POOL[sid] = user.model_dump()
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id].append(sid)
|
||||
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
|
||||
else:
|
||||
USER_POOL[user.id] = [sid]
|
||||
|
||||
# Join all the channels
|
||||
channels = Channels.get_channels_by_user_id(user.id)
|
||||
log.debug(f"{channels=}")
|
||||
for channel in channels:
|
||||
await sio.enter_room(sid, f"channel:{channel.id}")
|
||||
|
||||
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
||||
|
||||
await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
||||
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
|
||||
return {"id": user.id, "name": user.name}
|
||||
|
||||
|
||||
@sio.on("user-count")
|
||||
async def user_count(sid):
|
||||
await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
||||
@sio.on("join-channels")
|
||||
async def join_channel(sid, data):
|
||||
auth = data["auth"] if "auth" in data else None
|
||||
if not auth or "token" not in auth:
|
||||
return
|
||||
|
||||
data = decode_token(auth["token"])
|
||||
if data is None or "id" not in data:
|
||||
return
|
||||
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if not user:
|
||||
return
|
||||
|
||||
# Join all the channels
|
||||
channels = Channels.get_channels_by_user_id(user.id)
|
||||
log.debug(f"{channels=}")
|
||||
for channel in channels:
|
||||
await sio.enter_room(sid, f"channel:{channel.id}")
|
||||
|
||||
|
||||
@sio.on("chat")
|
||||
async def chat(sid, data):
|
||||
print("chat", sid, SESSION_POOL[sid], data)
|
||||
@sio.on("channel-events")
|
||||
async def channel_events(sid, data):
|
||||
room = f"channel:{data['channel_id']}"
|
||||
participants = sio.manager.get_participants(
|
||||
namespace="/",
|
||||
room=room,
|
||||
)
|
||||
|
||||
sids = [sid for sid, _ in participants]
|
||||
if sid not in sids:
|
||||
return
|
||||
|
||||
event_data = data["data"]
|
||||
event_type = event_data["type"]
|
||||
|
||||
if event_type == "typing":
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": data["channel_id"],
|
||||
"message_id": data.get("message_id", None),
|
||||
"data": event_data,
|
||||
"user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
|
||||
},
|
||||
room=room,
|
||||
)
|
||||
|
||||
|
||||
@sio.on("user-list")
|
||||
async def user_list(sid):
|
||||
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
|
||||
|
||||
|
||||
@sio.event
|
||||
async def disconnect(sid):
|
||||
if sid in SESSION_POOL:
|
||||
user_id = SESSION_POOL[sid]
|
||||
user = SESSION_POOL[sid]
|
||||
del SESSION_POOL[sid]
|
||||
|
||||
user_id = user["id"]
|
||||
USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
|
||||
|
||||
if len(USER_POOL[user_id]) == 0:
|
||||
del USER_POOL[user_id]
|
||||
|
||||
await sio.emit("user-count", {"count": len(USER_POOL)})
|
||||
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
|
||||
else:
|
||||
pass
|
||||
# print(f"Unknown session ID {sid} disconnected")
|
||||
@@ -195,16 +270,57 @@ async def disconnect(sid):
|
||||
|
||||
def get_event_emitter(request_info):
|
||||
async def __event_emitter__(event_data):
|
||||
await sio.emit(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": request_info["chat_id"],
|
||||
"message_id": request_info["message_id"],
|
||||
"data": event_data,
|
||||
},
|
||||
to=request_info["session_id"],
|
||||
user_id = request_info["user_id"]
|
||||
session_ids = list(
|
||||
set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
|
||||
)
|
||||
|
||||
for session_id in session_ids:
|
||||
await sio.emit(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": request_info["chat_id"],
|
||||
"message_id": request_info["message_id"],
|
||||
"data": event_data,
|
||||
},
|
||||
to=session_id,
|
||||
)
|
||||
|
||||
if "type" in event_data and event_data["type"] == "status":
|
||||
Chats.add_message_status_to_chat_by_id_and_message_id(
|
||||
request_info["chat_id"],
|
||||
request_info["message_id"],
|
||||
event_data.get("data", {}),
|
||||
)
|
||||
|
||||
if "type" in event_data and event_data["type"] == "message":
|
||||
message = Chats.get_message_by_id_and_message_id(
|
||||
request_info["chat_id"],
|
||||
request_info["message_id"],
|
||||
)
|
||||
|
||||
content = message.get("content", "")
|
||||
content += event_data.get("data", {}).get("content", "")
|
||||
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
request_info["chat_id"],
|
||||
request_info["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
if "type" in event_data and event_data["type"] == "replace":
|
||||
content = event_data.get("data", {}).get("content", "")
|
||||
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
request_info["chat_id"],
|
||||
request_info["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
return __event_emitter__
|
||||
|
||||
|
||||
@@ -222,3 +338,30 @@ def get_event_call(request_info):
|
||||
return response
|
||||
|
||||
return __event_call__
|
||||
|
||||
|
||||
def get_user_id_from_session_pool(sid):
|
||||
user = SESSION_POOL.get(sid)
|
||||
if user:
|
||||
return user["id"]
|
||||
return None
|
||||
|
||||
|
||||
def get_user_ids_from_room(room):
|
||||
active_session_ids = sio.manager.get_participants(
|
||||
namespace="/",
|
||||
room=room,
|
||||
)
|
||||
|
||||
active_user_ids = list(
|
||||
set(
|
||||
[SESSION_POOL.get(session_id[0])["id"] for session_id in active_session_ids]
|
||||
)
|
||||
)
|
||||
return active_user_ids
|
||||
|
||||
|
||||
def get_active_status_by_user_id(user_id):
|
||||
if user_id in USER_POOL:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1,5 +1,33 @@
|
||||
import json
|
||||
import redis
|
||||
import uuid
|
||||
|
||||
|
||||
class RedisLock:
|
||||
def __init__(self, redis_url, lock_name, timeout_secs):
|
||||
self.lock_name = lock_name
|
||||
self.lock_id = str(uuid.uuid4())
|
||||
self.timeout_secs = timeout_secs
|
||||
self.lock_obtained = False
|
||||
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
|
||||
|
||||
def aquire_lock(self):
|
||||
# nx=True will only set this key if it _hasn't_ already been set
|
||||
self.lock_obtained = self.redis.set(
|
||||
self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
|
||||
)
|
||||
return self.lock_obtained
|
||||
|
||||
def renew_lock(self):
|
||||
# xx=True will only set this key if it _has_ already been set
|
||||
return self.redis.set(
|
||||
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
|
||||
)
|
||||
|
||||
def release_lock(self):
|
||||
lock_value = self.redis.get(self.lock_name)
|
||||
if lock_value and lock_value.decode("utf-8") == self.lock_id:
|
||||
self.redis.delete(self.lock_name)
|
||||
|
||||
|
||||
class RedisDict:
|
||||
|
||||
@@ -26,8 +26,8 @@
|
||||
|
||||
html {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR',
|
||||
'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto,
|
||||
'Helvetica Neue', Arial, sans-serif;
|
||||
'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium',
|
||||
Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||
font-size: 14px; /* Default font size */
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
BIN
backend/open_webui/static/swagger-ui/favicon.png
Normal file
BIN
backend/open_webui/static/swagger-ui/favicon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.9 KiB |
72298
backend/open_webui/static/swagger-ui/swagger-ui-bundle.js
Normal file
72298
backend/open_webui/static/swagger-ui/swagger-ui-bundle.js
Normal file
File diff suppressed because one or more lines are too long
9312
backend/open_webui/static/swagger-ui/swagger-ui.css
Normal file
9312
backend/open_webui/static/swagger-ui/swagger-ui.css
Normal file
File diff suppressed because it is too large
Load Diff
@@ -147,8 +147,10 @@ class StorageProvider:
|
||||
return self._get_file_from_s3(file_path)
|
||||
return self._get_file_from_local(file_path)
|
||||
|
||||
def delete_file(self, filename: str) -> None:
|
||||
def delete_file(self, file_path: str) -> None:
|
||||
"""Deletes a file either from S3 or the local file system."""
|
||||
filename = file_path.split("/")[-1]
|
||||
|
||||
if self.storage_provider == "s3":
|
||||
self._delete_from_s3(filename)
|
||||
|
||||
|
||||
61
backend/open_webui/tasks.py
Normal file
61
backend/open_webui/tasks.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# tasks.py
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from uuid import uuid4
|
||||
|
||||
# A dictionary to keep track of active tasks
|
||||
tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
def cleanup_task(task_id: str):
|
||||
"""
|
||||
Remove a completed or canceled task from the global `tasks` dictionary.
|
||||
"""
|
||||
tasks.pop(task_id, None) # Remove the task if it exists
|
||||
|
||||
|
||||
def create_task(coroutine):
|
||||
"""
|
||||
Create a new asyncio task and add it to the global task dictionary.
|
||||
"""
|
||||
task_id = str(uuid4()) # Generate a unique ID for the task
|
||||
task = asyncio.create_task(coroutine) # Create the task
|
||||
|
||||
# Add a done callback for cleanup
|
||||
task.add_done_callback(lambda t: cleanup_task(task_id))
|
||||
|
||||
tasks[task_id] = task
|
||||
return task_id, task
|
||||
|
||||
|
||||
def get_task(task_id: str):
|
||||
"""
|
||||
Retrieve a task by its task ID.
|
||||
"""
|
||||
return tasks.get(task_id)
|
||||
|
||||
|
||||
def list_tasks():
|
||||
"""
|
||||
List all currently active task IDs.
|
||||
"""
|
||||
return list(tasks.keys())
|
||||
|
||||
|
||||
async def stop_task(task_id: str):
|
||||
"""
|
||||
Cancel a running task and remove it from the global task list.
|
||||
"""
|
||||
task = tasks.get(task_id)
|
||||
if not task:
|
||||
raise ValueError(f"Task with ID {task_id} not found.")
|
||||
|
||||
task.cancel() # Request task cancellation
|
||||
try:
|
||||
await task # Wait for the task to handle the cancellation
|
||||
except asyncio.CancelledError:
|
||||
# Task successfully canceled
|
||||
tasks.pop(task_id, None) # Remove it from the dictionary
|
||||
return {"status": True, "message": f"Task {task_id} successfully stopped."}
|
||||
|
||||
return {"status": False, "message": f"Failed to stop task {task_id}."}
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
from open_webui.models.users import Users, UserModel
|
||||
from open_webui.models.groups import Groups
|
||||
import json
|
||||
|
||||
@@ -93,3 +94,24 @@ def has_access(
|
||||
return user_id in permitted_user_ids or any(
|
||||
group_id in permitted_group_ids for group_id in user_group_ids
|
||||
)
|
||||
|
||||
|
||||
# Get all users with access to a resource
|
||||
def get_users_with_access(
|
||||
type: str = "write", access_control: Optional[dict] = None
|
||||
) -> List[UserModel]:
|
||||
if access_control is None:
|
||||
return Users.get_users()
|
||||
|
||||
permission_access = access_control.get(type, {})
|
||||
permitted_group_ids = permission_access.get("group_ids", [])
|
||||
permitted_user_ids = permission_access.get("user_ids", [])
|
||||
|
||||
user_ids_with_access = set(permitted_user_ids)
|
||||
|
||||
for group_id in permitted_group_ids:
|
||||
group_user_ids = Groups.get_group_user_ids_by_id(group_id)
|
||||
if group_user_ids:
|
||||
user_ids_with_access.update(group_user_ids)
|
||||
|
||||
return Users.get_users_by_user_ids(list(user_ids_with_access))
|
||||
|
||||
@@ -95,6 +95,20 @@ def get_current_user(
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
|
||||
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
|
||||
allowed_paths = [
|
||||
path.strip()
|
||||
for path in str(
|
||||
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
|
||||
).split(",")
|
||||
]
|
||||
|
||||
if request.url.path not in allowed_paths:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
|
||||
return get_current_user_by_api_key(token)
|
||||
|
||||
# auth by jwt token
|
||||
|
||||
@@ -89,7 +89,7 @@ async def generate_chat_completion(
|
||||
if model_ids and filter_mode == "exclude":
|
||||
model_ids = [
|
||||
model["id"]
|
||||
for model in await get_all_models(request)
|
||||
for model in list(request.app.state.MODELS.values())
|
||||
if model.get("owned_by") != "arena" and model["id"] not in model_ids
|
||||
]
|
||||
|
||||
@@ -99,7 +99,7 @@ async def generate_chat_completion(
|
||||
else:
|
||||
model_ids = [
|
||||
model["id"]
|
||||
for model in await get_all_models(request)
|
||||
for model in list(request.app.state.MODELS.values())
|
||||
if model.get("owned_by") != "arena"
|
||||
]
|
||||
selected_model_id = random.choice(model_ids)
|
||||
@@ -114,21 +114,27 @@ async def generate_chat_completion(
|
||||
yield chunk
|
||||
|
||||
response = await generate_chat_completion(
|
||||
form_data, user, bypass_filter=True
|
||||
request, form_data, user, bypass_filter=True
|
||||
)
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator), media_type="text/event-stream"
|
||||
stream_wrapper(response.body_iterator),
|
||||
media_type="text/event-stream",
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return {
|
||||
**(await generate_chat_completion(form_data, user, bypass_filter=True)),
|
||||
**(
|
||||
await generate_chat_completion(
|
||||
request, form_data, user, bypass_filter=True
|
||||
)
|
||||
),
|
||||
"selected_model_id": selected_model_id,
|
||||
}
|
||||
|
||||
if model.get("pipe"):
|
||||
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
|
||||
return await generate_function_chat_completion(
|
||||
form_data, user=user, models=models
|
||||
request, form_data, user=user, models=models
|
||||
)
|
||||
if model["owned_by"] == "ollama":
|
||||
# Using /ollama/api/chat endpoint
|
||||
@@ -141,6 +147,7 @@ async def generate_chat_completion(
|
||||
return StreamingResponse(
|
||||
convert_streaming_response_ollama_to_openai(response),
|
||||
headers=dict(response.headers),
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return convert_response_ollama_to_openai(response)
|
||||
@@ -150,8 +157,12 @@ async def generate_chat_completion(
|
||||
)
|
||||
|
||||
|
||||
chat_completion = generate_chat_completion
|
||||
|
||||
|
||||
async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
await get_all_models(request)
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
models = request.app.state.MODELS
|
||||
|
||||
data = form_data
|
||||
@@ -171,6 +182,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -179,6 +191,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -286,7 +299,8 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
if not action:
|
||||
raise Exception(f"Action not found: {action_id}")
|
||||
|
||||
await get_all_models(request)
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
models = request.app.state.MODELS
|
||||
|
||||
data = form_data
|
||||
@@ -301,6 +315,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
__event_call__ = get_event_call(
|
||||
@@ -308,6 +323,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -16,14 +16,16 @@ log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
|
||||
default_headers = {"User-Agent": "Mozilla/5.0"}
|
||||
|
||||
|
||||
def queue_prompt(prompt, client_id, base_url):
|
||||
def queue_prompt(prompt, client_id, base_url, api_key):
|
||||
log.info("queue_prompt")
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
data = json.dumps(p).encode("utf-8")
|
||||
log.debug(f"queue_prompt data: {data}")
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"{base_url}/prompt", data=data, headers=default_headers
|
||||
f"{base_url}/prompt",
|
||||
data=data,
|
||||
headers={**default_headers, "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
response = urllib.request.urlopen(req).read()
|
||||
return json.loads(response)
|
||||
@@ -32,12 +34,13 @@ def queue_prompt(prompt, client_id, base_url):
|
||||
raise e
|
||||
|
||||
|
||||
def get_image(filename, subfolder, folder_type, base_url):
|
||||
def get_image(filename, subfolder, folder_type, base_url, api_key):
|
||||
log.info("get_image")
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
url_values = urllib.parse.urlencode(data)
|
||||
req = urllib.request.Request(
|
||||
f"{base_url}/view?{url_values}", headers=default_headers
|
||||
f"{base_url}/view?{url_values}",
|
||||
headers={**default_headers, "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
with urllib.request.urlopen(req) as response:
|
||||
return response.read()
|
||||
@@ -50,18 +53,19 @@ def get_image_url(filename, subfolder, folder_type, base_url):
|
||||
return f"{base_url}/view?{url_values}"
|
||||
|
||||
|
||||
def get_history(prompt_id, base_url):
|
||||
def get_history(prompt_id, base_url, api_key):
|
||||
log.info("get_history")
|
||||
|
||||
req = urllib.request.Request(
|
||||
f"{base_url}/history/{prompt_id}", headers=default_headers
|
||||
f"{base_url}/history/{prompt_id}",
|
||||
headers={**default_headers, "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
with urllib.request.urlopen(req) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
|
||||
def get_images(ws, prompt, client_id, base_url):
|
||||
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
|
||||
def get_images(ws, prompt, client_id, base_url, api_key):
|
||||
prompt_id = queue_prompt(prompt, client_id, base_url, api_key)["prompt_id"]
|
||||
output_images = []
|
||||
while True:
|
||||
out = ws.recv()
|
||||
@@ -74,7 +78,7 @@ def get_images(ws, prompt, client_id, base_url):
|
||||
else:
|
||||
continue # previews are binary data
|
||||
|
||||
history = get_history(prompt_id, base_url)[prompt_id]
|
||||
history = get_history(prompt_id, base_url, api_key)[prompt_id]
|
||||
for o in history["outputs"]:
|
||||
for node_id in history["outputs"]:
|
||||
node_output = history["outputs"][node_id]
|
||||
@@ -113,7 +117,7 @@ class ComfyUIGenerateImageForm(BaseModel):
|
||||
|
||||
|
||||
async def comfyui_generate_image(
|
||||
model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
|
||||
model: str, payload: ComfyUIGenerateImageForm, client_id, base_url, api_key
|
||||
):
|
||||
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
workflow = json.loads(payload.workflow.workflow)
|
||||
@@ -167,7 +171,8 @@ async def comfyui_generate_image(
|
||||
|
||||
try:
|
||||
ws = websocket.WebSocket()
|
||||
ws.connect(f"{ws_url}/ws?clientId={client_id}")
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers)
|
||||
log.info("WebSocket connection established.")
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to connect to WebSocket server: {e}")
|
||||
@@ -176,7 +181,9 @@ async def comfyui_generate_image(
|
||||
try:
|
||||
log.info("Sending workflow to WebSocket server.")
|
||||
log.info(f"Workflow: {workflow}")
|
||||
images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
|
||||
images = await asyncio.to_thread(
|
||||
get_images, ws, workflow, client_id, base_url, api_key
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error while receiving images: {e}")
|
||||
images = None
|
||||
|
||||
@@ -2,21 +2,36 @@ import time
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import asyncio
|
||||
from aiocache import cached
|
||||
from typing import Any, Optional
|
||||
import random
|
||||
import json
|
||||
import inspect
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.socket.main import (
|
||||
get_event_call,
|
||||
get_event_emitter,
|
||||
get_active_status_by_user_id,
|
||||
)
|
||||
from open_webui.routers.tasks import generate_queries
|
||||
from open_webui.routers.tasks import (
|
||||
generate_queries,
|
||||
generate_title,
|
||||
generate_chat_tags,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
@@ -33,16 +48,25 @@ from open_webui.utils.task import (
|
||||
tools_function_calling_generation_template,
|
||||
)
|
||||
from open_webui.utils.misc import (
|
||||
get_message_list,
|
||||
add_or_update_system_message,
|
||||
get_last_user_message,
|
||||
get_last_assistant_message,
|
||||
prepend_to_first_user_message_content,
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
|
||||
|
||||
from open_webui.tasks import create_task
|
||||
|
||||
from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
ENABLE_REALTIME_CHAT_SAVE,
|
||||
)
|
||||
from open_webui.constants import TASKS
|
||||
|
||||
|
||||
@@ -312,6 +336,156 @@ async def chat_completion_tools_handler(
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def chat_web_search_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
event_emitter = extra_params["__event_emitter__"]
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "Generating search query",
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
messages = form_data["messages"]
|
||||
user_message = get_last_user_message(messages)
|
||||
|
||||
queries = []
|
||||
try:
|
||||
res = await generate_queries(
|
||||
request,
|
||||
{
|
||||
"model": form_data["model"],
|
||||
"messages": messages,
|
||||
"prompt": user_message,
|
||||
"type": "web_search",
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
response = res["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
bracket_start = response.find("{")
|
||||
bracket_end = response.rfind("}") + 1
|
||||
|
||||
if bracket_start == -1 or bracket_end == -1:
|
||||
raise Exception("No JSON object found in the response")
|
||||
|
||||
response = response[bracket_start:bracket_end]
|
||||
queries = json.loads(response)
|
||||
queries = queries.get("queries", [])
|
||||
except Exception as e:
|
||||
queries = [response]
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
queries = [user_message]
|
||||
|
||||
if len(queries) == 0:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search query generated",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
searchQuery = queries[0]
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
# Offload process_web_search to a separate thread
|
||||
loop = asyncio.get_running_loop()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = await loop.run_in_executor(
|
||||
executor,
|
||||
lambda: process_web_search(
|
||||
request,
|
||||
SearchForm(
|
||||
**{
|
||||
"query": searchQuery,
|
||||
}
|
||||
),
|
||||
user,
|
||||
),
|
||||
)
|
||||
|
||||
if results:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "Searched {{count}} sites",
|
||||
"query": searchQuery,
|
||||
"urls": results["filenames"],
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
files = form_data.get("files", [])
|
||||
files.append(
|
||||
{
|
||||
"collection_name": results["collection_name"],
|
||||
"name": searchQuery,
|
||||
"type": "web_search_results",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
form_data["files"] = files
|
||||
else:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search results found",
|
||||
"query": searchQuery,
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Error searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_completion_files_handler(
|
||||
request: Request, body: dict, user: UserModel
|
||||
) -> tuple[dict, dict[str, list]]:
|
||||
@@ -320,6 +494,7 @@ async def chat_completion_files_handler(
|
||||
if files := body.get("metadata", {}).get("files", None):
|
||||
try:
|
||||
queries_response = await generate_queries(
|
||||
request,
|
||||
{
|
||||
"model": body["model"],
|
||||
"messages": body["messages"],
|
||||
@@ -362,19 +537,44 @@ async def chat_completion_files_handler(
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def process_chat_payload(request, form_data, user, model):
|
||||
metadata = {
|
||||
"chat_id": form_data.pop("chat_id", None),
|
||||
"message_id": form_data.pop("id", None),
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
"tool_ids": form_data.get("tool_ids", None),
|
||||
"files": form_data.get("files", None),
|
||||
}
|
||||
form_data["metadata"] = metadata
|
||||
def apply_params_to_form_data(form_data, model):
|
||||
params = form_data.pop("params", {})
|
||||
if model.get("ollama"):
|
||||
form_data["options"] = params
|
||||
|
||||
if "format" in params:
|
||||
form_data["format"] = params["format"]
|
||||
|
||||
if "keep_alive" in params:
|
||||
form_data["keep_alive"] = params["keep_alive"]
|
||||
else:
|
||||
if "seed" in params:
|
||||
form_data["seed"] = params["seed"]
|
||||
|
||||
if "stop" in params:
|
||||
form_data["stop"] = params["stop"]
|
||||
|
||||
if "temperature" in params:
|
||||
form_data["temperature"] = params["temperature"]
|
||||
|
||||
if "top_p" in params:
|
||||
form_data["top_p"] = params["top_p"]
|
||||
|
||||
if "frequency_penalty" in params:
|
||||
form_data["frequency_penalty"] = params["frequency_penalty"]
|
||||
return form_data
|
||||
|
||||
|
||||
async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
form_data = apply_params_to_form_data(form_data, model)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": get_event_emitter(metadata),
|
||||
"__event_call__": get_event_call(metadata),
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
@@ -388,18 +588,70 @@ async def process_chat_payload(request, form_data, user, model):
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
# Initialize contexts and citation
|
||||
models = request.app.state.MODELS
|
||||
|
||||
events = []
|
||||
sources = []
|
||||
|
||||
user_message = get_last_user_message(form_data["messages"])
|
||||
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
|
||||
|
||||
if model_knowledge:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "knowledge_search",
|
||||
"query": user_message,
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
knowledge_files = []
|
||||
for item in model_knowledge:
|
||||
if item.get("collection_name"):
|
||||
knowledge_files.append(
|
||||
{
|
||||
"id": item.get("collection_name"),
|
||||
"name": item.get("name"),
|
||||
"legacy": True,
|
||||
}
|
||||
)
|
||||
elif item.get("collection_names"):
|
||||
knowledge_files.append(
|
||||
{
|
||||
"name": item.get("name"),
|
||||
"type": "collection",
|
||||
"collection_names": item.get("collection_names"),
|
||||
"legacy": True,
|
||||
}
|
||||
)
|
||||
else:
|
||||
knowledge_files.append(item)
|
||||
|
||||
files = form_data.get("files", [])
|
||||
files.extend(knowledge_files)
|
||||
form_data["files"] = files
|
||||
|
||||
features = form_data.pop("features", None)
|
||||
if features:
|
||||
if "web_search" in features and features["web_search"]:
|
||||
form_data = await chat_web_search_handler(
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
)
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
raise Exception(f"Error: {e}")
|
||||
|
||||
tool_ids = form_data.pop("tool_ids", None)
|
||||
files = form_data.pop("files", None)
|
||||
# Remove files duplicates
|
||||
if files:
|
||||
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
|
||||
|
||||
metadata = {
|
||||
**metadata,
|
||||
@@ -478,31 +730,366 @@ async def process_chat_payload(request, form_data, user, model):
|
||||
if len(sources) > 0:
|
||||
events.append({"sources": sources})
|
||||
|
||||
if model_knowledge:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "knowledge_search",
|
||||
"query": user_message,
|
||||
"done": True,
|
||||
"hidden": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return form_data, events
|
||||
|
||||
|
||||
async def process_chat_response(response, events):
|
||||
async def process_chat_response(
|
||||
request, response, form_data, user, events, metadata, tasks
|
||||
):
|
||||
async def background_tasks_handler():
|
||||
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
|
||||
message = message_map.get(metadata["message_id"]) if message_map else None
|
||||
|
||||
if message:
|
||||
messages = get_message_list(message_map, message.get("id"))
|
||||
|
||||
if tasks:
|
||||
if TASKS.TITLE_GENERATION in tasks:
|
||||
if tasks[TASKS.TITLE_GENERATION]:
|
||||
res = await generate_title(
|
||||
request,
|
||||
{
|
||||
"model": message["model"],
|
||||
"messages": messages,
|
||||
"chat_id": metadata["chat_id"],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
if res and isinstance(res, dict):
|
||||
title = (
|
||||
res.get("choices", [])[0]
|
||||
.get("message", {})
|
||||
.get(
|
||||
"content",
|
||||
message.get("content", "New Chat"),
|
||||
)
|
||||
).strip()
|
||||
|
||||
if not title:
|
||||
title = messages[0].get("content", "New Chat")
|
||||
|
||||
Chats.update_chat_title_by_id(metadata["chat_id"], title)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:title",
|
||||
"data": title,
|
||||
}
|
||||
)
|
||||
elif len(messages) == 2:
|
||||
title = messages[0].get("content", "New Chat")
|
||||
|
||||
Chats.update_chat_title_by_id(metadata["chat_id"], title)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:title",
|
||||
"data": message.get("content", "New Chat"),
|
||||
}
|
||||
)
|
||||
|
||||
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
|
||||
res = await generate_chat_tags(
|
||||
request,
|
||||
{
|
||||
"model": message["model"],
|
||||
"messages": messages,
|
||||
"chat_id": metadata["chat_id"],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
if res and isinstance(res, dict):
|
||||
tags_string = (
|
||||
res.get("choices", [])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
)
|
||||
|
||||
tags_string = tags_string[
|
||||
tags_string.find("{") : tags_string.rfind("}") + 1
|
||||
]
|
||||
|
||||
try:
|
||||
tags = json.loads(tags_string).get("tags", [])
|
||||
Chats.update_chat_tags_by_id(
|
||||
metadata["chat_id"], tags, user
|
||||
)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:tags",
|
||||
"data": tags,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
event_emitter = None
|
||||
if (
|
||||
"session_id" in metadata
|
||||
and metadata["session_id"]
|
||||
and "chat_id" in metadata
|
||||
and metadata["chat_id"]
|
||||
and "message_id" in metadata
|
||||
and metadata["message_id"]
|
||||
):
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
|
||||
if not isinstance(response, StreamingResponse):
|
||||
if event_emitter:
|
||||
|
||||
if "selected_model_id" in response:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": response["selected_model_id"],
|
||||
},
|
||||
)
|
||||
|
||||
if response.get("choices", [])[0].get("message", {}).get("content"):
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
|
||||
if content:
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": response,
|
||||
}
|
||||
)
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": {
|
||||
"done": True,
|
||||
"content": content,
|
||||
"title": title,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if get_active_status_by_user_id(user.id) is None:
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
"action": "chat",
|
||||
"message": content,
|
||||
"title": title,
|
||||
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
||||
},
|
||||
)
|
||||
|
||||
await background_tasks_handler()
|
||||
|
||||
return response
|
||||
else:
|
||||
return response
|
||||
|
||||
if not any(
|
||||
content_type in response.headers["Content-Type"]
|
||||
for content_type in ["text/event-stream", "application/x-ndjson"]
|
||||
):
|
||||
return response
|
||||
|
||||
content_type = response.headers["Content-Type"]
|
||||
is_openai = "text/event-stream" in content_type
|
||||
is_ollama = "application/x-ndjson" in content_type
|
||||
if event_emitter:
|
||||
|
||||
if not is_openai and not is_ollama:
|
||||
return response
|
||||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
return f"data: {item}\n\n" if is_openai else f"{item}\n"
|
||||
# Handle as a background task
|
||||
async def post_response_handler(response, events):
|
||||
message = Chats.get_message_by_id_and_message_id(
|
||||
metadata["chat_id"], metadata["message_id"]
|
||||
)
|
||||
content = message.get("content", "") if message else ""
|
||||
|
||||
for event in events:
|
||||
yield wrap_item(json.dumps(event))
|
||||
try:
|
||||
for event in events:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": event,
|
||||
}
|
||||
)
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
**event,
|
||||
},
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator, events),
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
async for line in response.body_iterator:
|
||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
||||
data = line
|
||||
|
||||
# Skip empty lines
|
||||
if not data.strip():
|
||||
continue
|
||||
|
||||
# "data: " is the prefix for each event
|
||||
if not data.startswith("data: "):
|
||||
continue
|
||||
|
||||
# Remove the prefix
|
||||
data = data[len("data: ") :]
|
||||
|
||||
try:
|
||||
data = json.loads(data)
|
||||
|
||||
if "selected_model_id" in data:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": data["selected_model_id"],
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
value = (
|
||||
data.get("choices", [])[0]
|
||||
.get("delta", {})
|
||||
.get("content")
|
||||
)
|
||||
|
||||
if value:
|
||||
content = f"{content}{value}"
|
||||
|
||||
if ENABLE_REALTIME_CHAT_SAVE:
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = {
|
||||
"content": content,
|
||||
}
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
done = "data: [DONE]" in line
|
||||
|
||||
if done:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
data = {"done": True, "content": content, "title": title}
|
||||
|
||||
if not ENABLE_REALTIME_CHAT_SAVE:
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if get_active_status_by_user_id(user.id) is None:
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
"action": "chat",
|
||||
"message": content,
|
||||
"title": title,
|
||||
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
||||
},
|
||||
)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
await background_tasks_handler()
|
||||
except asyncio.CancelledError:
|
||||
print("Task was cancelled!")
|
||||
await event_emitter({"type": "task-cancelled"})
|
||||
|
||||
if not ENABLE_REALTIME_CHAT_SAVE:
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
if response.background is not None:
|
||||
await response.background()
|
||||
|
||||
# background_tasks.add_task(post_response_handler, response, events)
|
||||
task_id, _ = create_task(post_response_handler(response, events))
|
||||
return {"status": True, "task_id": task_id}
|
||||
|
||||
else:
|
||||
|
||||
# Fallback to the original response
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
return f"data: {item}\n\n"
|
||||
|
||||
for event in events:
|
||||
yield wrap_item(json.dumps(event))
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator, events),
|
||||
headers=dict(response.headers),
|
||||
background=response.background,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,34 @@ from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def get_message_list(messages, message_id):
|
||||
"""
|
||||
Reconstructs a list of messages in order up to the specified message_id.
|
||||
|
||||
:param message_id: ID of the message to reconstruct the chain
|
||||
:param messages: Message history dict containing all messages
|
||||
:return: List of ordered messages starting from the root to the given message
|
||||
"""
|
||||
|
||||
# Find the message by its id
|
||||
current_message = messages.get(message_id)
|
||||
|
||||
if not current_message:
|
||||
return f"Message ID {message_id} not found in the history."
|
||||
|
||||
# Reconstruct the chain by following the parentId links
|
||||
message_list = []
|
||||
|
||||
while current_message:
|
||||
message_list.insert(
|
||||
0, current_message
|
||||
) # Insert the message at the beginning of the list
|
||||
parent_id = current_message["parentId"]
|
||||
current_message = messages.get(parent_id) if parent_id else None
|
||||
|
||||
return message_list
|
||||
|
||||
|
||||
def get_messages_content(messages: list[dict]) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
@@ -40,6 +68,13 @@ def get_last_user_message(messages: list[dict]) -> Optional[str]:
|
||||
return get_content_from_message(message)
|
||||
|
||||
|
||||
def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
|
||||
for message in reversed(messages):
|
||||
if message["role"] == "assistant":
|
||||
return message
|
||||
return None
|
||||
|
||||
|
||||
def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
|
||||
for message in reversed(messages):
|
||||
if message["role"] == "assistant":
|
||||
|
||||
@@ -58,7 +58,6 @@ async def get_all_base_models(request: Request):
|
||||
return models
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request):
|
||||
models = await get_all_base_models(request)
|
||||
|
||||
|
||||
@@ -14,13 +14,16 @@ from starlette.responses import RedirectResponse
|
||||
|
||||
from open_webui.models.auths import Auths
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm
|
||||
from open_webui.config import (
|
||||
DEFAULT_USER_ROLE,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
||||
OAUTH_PROVIDERS,
|
||||
ENABLE_OAUTH_ROLE_MANAGEMENT,
|
||||
ENABLE_OAUTH_GROUP_MANAGEMENT,
|
||||
OAUTH_ROLES_CLAIM,
|
||||
OAUTH_GROUPS_CLAIM,
|
||||
OAUTH_EMAIL_CLAIM,
|
||||
OAUTH_PICTURE_CLAIM,
|
||||
OAUTH_USERNAME_CLAIM,
|
||||
@@ -44,7 +47,9 @@ auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
|
||||
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
|
||||
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
|
||||
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
|
||||
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
|
||||
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
|
||||
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
|
||||
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
|
||||
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
|
||||
@@ -119,6 +124,61 @@ class OAuthManager:
|
||||
|
||||
return role
|
||||
|
||||
def update_user_groups(self, user, user_data, default_permissions):
|
||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||
|
||||
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
|
||||
# Remove groups that user is no longer a part of
|
||||
for group_model in user_current_groups:
|
||||
if group_model.name not in user_oauth_groups:
|
||||
# Remove group from user
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids = [i for i in user_ids if i != user.id]
|
||||
|
||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||
group_permissions = group_model.permissions
|
||||
if not group_permissions:
|
||||
group_permissions = default_permissions
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=group_model.name,
|
||||
description=group_model.description,
|
||||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
# Add user to new groups
|
||||
for group_model in all_available_groups:
|
||||
if group_model.name in user_oauth_groups and not any(
|
||||
gm.name == group_model.name for gm in user_current_groups
|
||||
):
|
||||
# Add user to group
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids.append(user.id)
|
||||
|
||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||
group_permissions = group_model.permissions
|
||||
if not group_permissions:
|
||||
group_permissions = default_permissions
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=group_model.name,
|
||||
description=group_model.description,
|
||||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
async def handle_login(self, provider, request):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
@@ -254,6 +314,13 @@ class OAuthManager:
|
||||
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
||||
)
|
||||
|
||||
if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT:
|
||||
self.update_user_groups(
|
||||
user=user,
|
||||
user_data=user_data,
|
||||
default_permissions=request.app.state.config.USER_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
|
||||
@@ -154,9 +154,16 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
)
|
||||
ollama_payload["stream"] = openai_payload.get("stream", False)
|
||||
|
||||
if "format" in openai_payload:
|
||||
ollama_payload["format"] = openai_payload["format"]
|
||||
|
||||
# If there are advanced parameters in the payload, format them in Ollama's options field
|
||||
ollama_options = {}
|
||||
|
||||
if openai_payload.get("options"):
|
||||
ollama_payload["options"] = openai_payload["options"]
|
||||
ollama_options = openai_payload["options"]
|
||||
|
||||
# Handle parameters which map directly
|
||||
for param in ["temperature", "top_p", "seed"]:
|
||||
if param in openai_payload:
|
||||
|
||||
@@ -29,7 +29,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
(
|
||||
(
|
||||
data.get("eval_count", 0)
|
||||
/ ((data.get("eval_duration", 0) / 1_000_000_000))
|
||||
/ ((data.get("eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
@@ -43,12 +43,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
(
|
||||
(
|
||||
data.get("prompt_eval_count", 0)
|
||||
/ (
|
||||
(
|
||||
data.get("prompt_eval_duration", 0)
|
||||
/ 1_000_000_000
|
||||
)
|
||||
)
|
||||
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
@@ -57,20 +52,12 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
if data.get("prompt_eval_duration", 0) > 0
|
||||
else "N/A"
|
||||
),
|
||||
"total_duration": round(
|
||||
((data.get("total_duration", 0) / 1_000_000) * 100), 2
|
||||
),
|
||||
"load_duration": round(
|
||||
((data.get("load_duration", 0) / 1_000_000) * 100), 2
|
||||
),
|
||||
"total_duration": data.get("total_duration", 0),
|
||||
"load_duration": data.get("load_duration", 0),
|
||||
"prompt_eval_count": data.get("prompt_eval_count", 0),
|
||||
"prompt_eval_duration": round(
|
||||
((data.get("prompt_eval_duration", 0) / 1_000_000) * 100), 2
|
||||
),
|
||||
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
|
||||
"eval_count": data.get("eval_count", 0),
|
||||
"eval_duration": round(
|
||||
((data.get("eval_duration", 0) / 1_000_000) * 100), 2
|
||||
),
|
||||
"eval_duration": data.get("eval_duration", 0),
|
||||
"approximate_total": (
|
||||
lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s"
|
||||
)((data.get("total_duration", 0) or 0) // 1_000_000_000),
|
||||
|
||||
@@ -11,6 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
|
||||
|
||||
def post_webhook(url: str, message: str, event_data: dict) -> bool:
|
||||
try:
|
||||
log.debug(f"post_webhook: {url}, {message}, {event_data}")
|
||||
payload = {}
|
||||
|
||||
# Slack and Google Chat Webhooks
|
||||
@@ -18,7 +19,11 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
|
||||
payload["text"] = message
|
||||
# Discord Webhooks
|
||||
elif "https://discord.com/api/webhooks" in url:
|
||||
payload["content"] = message
|
||||
payload["content"] = (
|
||||
message
|
||||
if len(message) < 2000
|
||||
else f"{message[: 2000 - 20]}... (truncated)"
|
||||
)
|
||||
# Microsoft Teams Webhooks
|
||||
elif "webhook.office.com" in url:
|
||||
action = event_data.get("action", "undefined")
|
||||
|
||||
@@ -3,7 +3,7 @@ uvicorn[standard]==0.30.6
|
||||
pydantic==2.9.2
|
||||
python-multipart==0.0.18
|
||||
|
||||
Flask==3.0.3
|
||||
Flask==3.1.0
|
||||
Flask-Cors==5.0.0
|
||||
|
||||
python-socketio==5.11.3
|
||||
@@ -18,7 +18,7 @@ aiofiles
|
||||
|
||||
sqlalchemy==2.0.32
|
||||
alembic==1.14.0
|
||||
peewee==3.17.6
|
||||
peewee==3.17.8
|
||||
peewee-migrate==1.12.2
|
||||
psycopg2-binary==2.9.9
|
||||
pgvector==0.3.5
|
||||
@@ -55,7 +55,7 @@ einops==0.8.0
|
||||
|
||||
ftfy==6.2.3
|
||||
pypdf==4.3.1
|
||||
fpdf2==2.7.9
|
||||
fpdf2==2.8.2
|
||||
pymdown-extensions==10.11.2
|
||||
docx2txt==0.8
|
||||
python-pptx==1.0.0
|
||||
@@ -67,7 +67,7 @@ pandas==2.2.3
|
||||
openpyxl==3.1.5
|
||||
pyxlsb==1.0.10
|
||||
xlrd==2.0.1
|
||||
validators==0.33.0
|
||||
validators==0.34.0
|
||||
psutil
|
||||
sentencepiece
|
||||
soundfile==0.12.1
|
||||
@@ -78,7 +78,7 @@ rank-bm25==0.2.2
|
||||
|
||||
faster-whisper==1.0.3
|
||||
|
||||
PyJWT[crypto]==2.9.0
|
||||
PyJWT[crypto]==2.10.1
|
||||
authlib==1.3.2
|
||||
|
||||
black==24.8.0
|
||||
@@ -90,6 +90,11 @@ extract_msg
|
||||
pydub
|
||||
duckduckgo-search~=6.3.5
|
||||
|
||||
## Google Drive
|
||||
google-api-python-client
|
||||
google-auth-httplib2
|
||||
google-auth-oauthlib
|
||||
|
||||
## Tests
|
||||
docker~=7.1.0
|
||||
pytest~=8.3.2
|
||||
|
||||
Reference in New Issue
Block a user