Merge branch 'open-webui:main' into main

MadsLang
2025-01-13 08:28:13 +01:00
committed by GitHub
4022 changed files with 103589 additions and 2083 deletions

View File

@@ -21,7 +21,7 @@ from open_webui.env import (
WEBUI_NAME,
log,
DATABASE_URL,
-    OFFLINE_MODE
+    OFFLINE_MODE,
)
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -272,6 +272,18 @@ ENABLE_API_KEY = PersistentConfig(
os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
)
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = PersistentConfig(
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS",
"auth.api_key.endpoint_restrictions",
os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False").lower() == "true",
)
API_KEY_ALLOWED_ENDPOINTS = PersistentConfig(
"API_KEY_ALLOWED_ENDPOINTS",
"auth.api_key.allowed_endpoints",
os.environ.get("API_KEY_ALLOWED_ENDPOINTS", ""),
)
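Illustrative configuration for the two settings above, assuming API_KEY_ALLOWED_ENDPOINTS is a comma-separated list of request paths (the paths shown are placeholders):

import os

# Hypothetical values; set before the config module is imported.
os.environ["ENABLE_API_KEY_ENDPOINT_RESTRICTIONS"] = "true"
os.environ["API_KEY_ALLOWED_ENDPOINTS"] = "/api/models,/api/chat/completions"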
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
@@ -307,6 +319,7 @@ GOOGLE_CLIENT_SECRET = PersistentConfig(
os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)
GOOGLE_OAUTH_SCOPE = PersistentConfig(
"GOOGLE_OAUTH_SCOPE",
"oauth.google.scope",
@@ -403,12 +416,24 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
OAUTH_GROUPS_CLAIM = PersistentConfig(
"OAUTH_GROUPS_CLAIM",
"oauth.oidc.group_claim",
os.environ.get("OAUTH_GROUP_CLAIM", "groups"),
)
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_ROLE_MANAGEMENT",
"oauth.enable_role_mapping",
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)
ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_GROUP_MANAGEMENT",
"oauth.enable_group_mapping",
os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
)
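Note the naming mismatch carried over from the source: the config key is OAUTH_GROUPS_CLAIM but the environment variable read is OAUTH_GROUP_CLAIM (singular). An illustrative setup for group mapping:

import os

os.environ["ENABLE_OAUTH_GROUP_MANAGEMENT"] = "true"
os.environ["OAUTH_GROUP_CLAIM"] = "groups"  # claim read from the provider's ID token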
OAUTH_ROLES_CLAIM = PersistentConfig(
"OAUTH_ROLES_CLAIM",
"oauth.roles_claim",
@@ -696,6 +721,12 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1"
# WEBUI
####################################
WEBUI_URL = PersistentConfig(
"WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
)
ENABLE_SIGNUP = PersistentConfig(
"ENABLE_SIGNUP",
"ui.enable_signup",
@@ -823,6 +854,12 @@ USER_PERMISSIONS = PersistentConfig(
},
)
ENABLE_CHANNELS = PersistentConfig(
"ENABLE_CHANNELS",
"channels.enable",
os.environ.get("ENABLE_CHANNELS", "False").lower() == "true",
)
ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
"ENABLE_EVALUATION_ARENA_MODELS",
@@ -1174,11 +1211,34 @@ if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
raise ValueError(
"Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
)
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)
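PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH fixes the dimension of the pgvector column at first initialization (see the check_vector_length guard added later in this commit), so it must match the embedding model's output size. An illustrative sketch:

import os

# Hypothetical value: must equal the embedding dimension, and cannot be
# changed after the document_chunk table exists without migrating data.
os.environ["PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH"] = "1024"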
####################################
# Information Retrieval (RAG)
####################################
# If configured, Google Drive will be available as an upload option.
ENABLE_GOOGLE_DRIVE_INTEGRATION = PersistentConfig(
"ENABLE_GOOGLE_DRIVE_INTEGRATION",
"google_drive.enable",
os.getenv("ENABLE_GOOGLE_DRIVE_INTEGRATION", "False").lower() == "true",
)
GOOGLE_DRIVE_CLIENT_ID = PersistentConfig(
"GOOGLE_DRIVE_CLIENT_ID",
"google_drive.client_id",
os.environ.get("GOOGLE_DRIVE_CLIENT_ID", ""),
)
GOOGLE_DRIVE_API_KEY = PersistentConfig(
"GOOGLE_DRIVE_API_KEY",
"google_drive.api_key",
os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
)
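Illustrative environment for the Google Drive upload option; both credentials are placeholders and would typically come from a Google Cloud project with the Drive and Picker APIs enabled:

import os

os.environ["ENABLE_GOOGLE_DRIVE_INTEGRATION"] = "true"
os.environ["GOOGLE_DRIVE_CLIENT_ID"] = "<client-id>.apps.googleusercontent.com"
os.environ["GOOGLE_DRIVE_API_KEY"] = "<browser-api-key>"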
# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
@@ -1253,7 +1313,8 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
-    not OFFLINE_MODE and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
@@ -1278,7 +1339,8 @@ if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
RAG_RERANKING_MODEL_AUTO_UPDATE = (
-    not OFFLINE_MODE and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
@@ -1412,6 +1474,7 @@ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
],
)
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
@@ -1587,6 +1650,12 @@ COMFYUI_BASE_URL = PersistentConfig(
os.getenv("COMFYUI_BASE_URL", ""),
)
COMFYUI_API_KEY = PersistentConfig(
"COMFYUI_API_KEY",
"image_generation.comfyui.api_key",
os.getenv("COMFYUI_API_KEY", ""),
)
COMFYUI_DEFAULT_WORKFLOW = """
{
"3": {
@@ -1748,7 +1817,8 @@ WHISPER_MODEL = PersistentConfig(
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
-    not OFFLINE_MODE and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
+    not OFFLINE_MODE
+    and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)

View File

@@ -53,6 +53,11 @@ if USE_CUDA.lower() == "true":
else:
DEVICE_TYPE = "cpu"
try:
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
DEVICE_TYPE = "mps"
except Exception:
pass
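The new try/except promotes DEVICE_TYPE to Apple's Metal backend when available. A self-contained sketch of the selection logic (the real code keys off USE_CUDA rather than probing CUDA directly):

import torch

DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # Older or non-macOS torch builds may lack the MPS backend attributes,
    # hence the broad except around the probe.
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        DEVICE_TYPE = "mps"
except Exception:
    pass

device = torch.device(DEVICE_TYPE)  # "cuda", "mps", or "cpu"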
####################################
# LOGGING
@@ -103,8 +108,6 @@ WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
-WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
@@ -315,6 +318,11 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
ENABLE_REALTIME_CHAT_SAVE = (
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
)
####################################
# REDIS
####################################
@@ -396,3 +404,6 @@ else:
####################################
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"

View File

@@ -55,7 +55,7 @@ def handle_peewee_migration(DATABASE_URL):
try:
# Replace the postgresql:// with postgres:// to handle the peewee migration
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
migrate_dir = OPEN_WEBUI_DIR / "apps" / "webui" / "internal" / "migrations"
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
router = Router(db, logger=log, migrate_dir=migrate_dir)
router.run()
db.close()

View File

@@ -18,6 +18,8 @@ from typing import Optional
from aiocache import cached
import aiohttp
import requests
from fastapi import (
Depends,
FastAPI,
@@ -27,7 +29,12 @@ from fastapi import (
Request,
UploadFile,
status,
applications,
BackgroundTasks,
)
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
@@ -51,6 +58,7 @@ from open_webui.routers import (
pipelines,
tasks,
auths,
channels,
chats,
folders,
configs,
@@ -96,6 +104,7 @@ from open_webui.config import (
AUTOMATIC1111_SAMPLER,
AUTOMATIC1111_SCHEDULER,
COMFYUI_BASE_URL,
COMFYUI_API_KEY,
COMFYUI_WORKFLOW,
COMFYUI_WORKFLOW_NODES,
ENABLE_IMAGE_GENERATION,
@@ -171,10 +180,13 @@ from open_webui.config import (
MOJEEK_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
GOOGLE_DRIVE_CLIENT_ID,
GOOGLE_DRIVE_API_KEY,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENABLE_GOOGLE_DRIVE_INTEGRATION,
UPLOAD_DIR,
# WebUI
WEBUI_AUTH,
@@ -187,6 +199,9 @@ from open_webui.config import (
ENABLE_SIGNUP,
ENABLE_LOGIN_FORM,
ENABLE_API_KEY,
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
API_KEY_ALLOWED_ENDPOINTS,
ENABLE_CHANNELS,
ENABLE_COMMUNITY_SHARING,
ENABLE_MESSAGE_RATING,
ENABLE_EVALUATION_ARENA_MODELS,
@@ -226,6 +241,7 @@ from open_webui.config import (
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
OAUTH_PROVIDERS,
WEBUI_URL,
# Admin
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
@@ -251,13 +267,13 @@ from open_webui.env import (
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
-    WEBUI_URL,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
+    ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
@@ -285,6 +301,7 @@ from open_webui.utils.auth import (
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
if SAFE_MODE:
print("SAFE MODE ENABLED")
@@ -374,9 +391,15 @@ app.state.OPENAI_MODELS = {}
#
########################################
app.state.config.WEBUI_URL = WEBUI_URL
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
)
app.state.config.API_KEY_ALLOWED_ENDPOINTS = API_KEY_ALLOWED_ENDPOINTS
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
@@ -393,6 +416,8 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
@@ -477,6 +502,7 @@ app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
@@ -504,6 +530,22 @@ app.state.rf = None
app.state.YOUTUBE_LOADER_TRANSLATION = None
try:
app.state.ef = get_ef(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
app.state.rf = get_rf(
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
except Exception as e:
log.error(f"Error updating models: {e}")
pass
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
@@ -521,21 +563,6 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
try:
app.state.ef = get_ef(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
app.state.rf = get_rf(
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
except Exception as e:
log.error(f"Error updating models: {e}")
pass
########################################
#
@@ -557,6 +584,7 @@ app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_API_KEY = COMFYUI_API_KEY
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
@@ -722,6 +750,8 @@ app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])
app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"])
app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
@@ -810,11 +840,11 @@ async def chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
-    bypass_filter: bool = False,
):
if not request.app.state.MODELS:
await get_all_models(request)
+    tasks = form_data.pop("background_tasks", None)
try:
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
@@ -822,13 +852,26 @@ async def chat_completion(
model = request.app.state.MODELS[model_id]
# Check if user has access to the model
-        if not bypass_filter and user.role == "user":
+        if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
form_data, events = await process_chat_payload(request, form_data, user, model)
metadata = {
"user_id": user.id,
"chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None),
"tool_ids": form_data.get("tool_ids", None),
"files": form_data.get("files", None),
"features": form_data.get("features", None),
}
form_data["metadata"] = metadata
form_data, events = await process_chat_payload(
request, form_data, metadata, user, model
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -836,10 +879,10 @@ async def chat_completion(
)
try:
-        response = await chat_completion_handler(
-            request, form_data, user, bypass_filter
-        )
-        return await process_chat_response(response, events)
+        response = await chat_completion_handler(request, form_data, user)
+        return await process_chat_response(
+            request, response, form_data, user, events, metadata, tasks
+        )
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -878,6 +921,20 @@ async def chat_action(
)
@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
try:
result = await stop_task(task_id) # Use the function from tasks.py
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
return {"tasks": list_tasks()} # Use the function from tasks.py
##################################
#
# Config Endpoints
@@ -925,9 +982,12 @@ async def get_app_config(request: Request):
"enable_api_key": app.state.config.ENABLE_API_KEY,
"enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
**(
{
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
@@ -938,6 +998,10 @@ async def get_app_config(request: Request):
else {}
),
},
"google_drive": {
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
"api_key": GOOGLE_DRIVE_API_KEY.value,
},
**(
{
"default_models": app.state.config.DEFAULT_MODELS,
@@ -1082,9 +1146,9 @@ async def get_opensearch_xml():
<ShortName>{WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>
<moz:SearchForm>{app.state.config.WEBUI_URL}</moz:SearchForm>
</OpenSearchDescription>
"""
return Response(content=xml_content, media_type="application/xml")
@@ -1104,6 +1168,19 @@ async def healthcheck_with_db():
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
def swagger_ui_html(*args, **kwargs):
return get_swagger_ui_html(
*args,
**kwargs,
swagger_js_url="/static/swagger-ui/swagger-ui-bundle.js",
swagger_css_url="/static/swagger-ui/swagger-ui.css",
swagger_favicon_url="/static/swagger-ui/favicon.png",
)
applications.get_swagger_ui_html = swagger_ui_html
if os.path.exists(FRONTEND_BUILD_DIR):
mimetypes.add_type("text/javascript", ".js")
app.mount(

View File

@@ -0,0 +1,70 @@
"""Update message & channel tables
Revision ID: 3781e22d8b01
Revises: 7826ab40b532
Create Date: 2024-12-30 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "3781e22d8b01"
down_revision = "7826ab40b532"
branch_labels = None
depends_on = None
def upgrade():
# Add 'type' column to the 'channel' table
op.add_column(
"channel",
sa.Column(
"type",
sa.Text(),
nullable=True,
),
)
# Add 'parent_id' column to the 'message' table for threads
op.add_column(
"message",
sa.Column("parent_id", sa.Text(), nullable=True),
)
op.create_table(
"message_reaction",
sa.Column(
"id", sa.Text(), nullable=False, primary_key=True, unique=True
), # Unique reaction ID
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted
sa.Column(
"message_id", sa.Text(), nullable=False
), # Message that was reacted to
sa.Column(
"name", sa.Text(), nullable=False
), # Reaction name (e.g. "thumbs_up")
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the reaction was added
)
op.create_table(
"channel_member",
sa.Column(
"id", sa.Text(), nullable=False, primary_key=True, unique=True
), # Record ID for the membership row
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the user joined the channel
)
def downgrade():
# Revert 'type' column addition to the 'channel' table
op.drop_column("channel", "type")
op.drop_column("message", "parent_id")
op.drop_table("message_reaction")
op.drop_table("channel_member")

View File

@@ -0,0 +1,48 @@
"""Add channel table
Revision ID: 57c599a3cb57
Revises: 922e7a387820
Create Date: 2024-12-22 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "57c599a3cb57"
down_revision = "922e7a387820"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"channel",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("name", sa.Text()),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
op.create_table(
"message",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("channel_id", sa.Text(), nullable=True),
sa.Column("content", sa.Text()),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
def downgrade():
op.drop_table("channel")
op.drop_table("message")

View File

@@ -0,0 +1,26 @@
"""Update file table
Revision ID: 7826ab40b532
Revises: 57c599a3cb57
Create Date: 2024-12-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "7826ab40b532"
down_revision = "57c599a3cb57"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"file",
sa.Column("access_control", sa.JSON(), nullable=True),
)
def downgrade():
op.drop_column("file", "access_control")

View File

@@ -0,0 +1,136 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Channel DB Schema
####################
class Channel(Base):
__tablename__ = "channel"
id = Column(Text, primary_key=True)
user_id = Column(Text)
type = Column(Text, nullable=True)
name = Column(Text)
description = Column(Text, nullable=True)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class ChannelModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
type: Optional[str] = None
name: str
description: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class ChannelForm(BaseModel):
name: str
description: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class ChannelTable:
def insert_new_channel(
self, type: Optional[str], form_data: ChannelForm, user_id: str
) -> Optional[ChannelModel]:
with get_db() as db:
channel = ChannelModel(
**{
**form_data.model_dump(),
"type": type,
"name": form_data.name.lower(),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_channel = Channel(**channel.model_dump())
db.add(new_channel)
db.commit()
return channel
def get_channels(self) -> list[ChannelModel]:
with get_db() as db:
channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_user_id(
self, user_id: str, permission: str = "read"
) -> list[ChannelModel]:
channels = self.get_channels()
return [
channel
for channel in channels
if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control)
]
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first()
return ChannelModel.model_validate(channel) if channel else None
def update_channel_by_id(
self, id: str, form_data: ChannelForm
) -> Optional[ChannelModel]:
with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first()
if not channel:
return None
channel.name = form_data.name
channel.data = form_data.data
channel.meta = form_data.meta
channel.access_control = form_data.access_control
channel.updated_at = int(time.time_ns())
db.commit()
return ChannelModel.model_validate(channel) if channel else None
def delete_channel_by_id(self, id: str):
with get_db() as db:
db.query(Channel).filter(Channel.id == id).delete()
db.commit()
return True
Channels = ChannelTable()
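A usage sketch for the new table wrapper (ids are illustrative):

form = ChannelForm(name="Announcements", description="Org-wide updates")
channel = Channels.insert_new_channel(None, form, user_id="admin-uuid")

# Names are lowercased on insert, so lookups stay case-insensitive.
assert channel.name == "announcements"

# Owner or anyone granted "read" via access_control sees the channel.
visible = Channels.get_channels_by_user_id("member-uuid")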

View File

@@ -168,6 +168,100 @@ class ChatTable:
except Exception:
return None
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
chat["title"] = title
return self.update_chat_by_id(id, chat)
def update_chat_tags_by_id(
self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
self.delete_all_tags_by_id_and_user_id(id, user.id)
for tag in chat.meta.get("tags", []):
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
for tag_name in tags:
if tag_name.lower() == "none":
continue
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
return self.get_chat_by_id(id)
def get_chat_title_by_id(self, id: str) -> Optional[str]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("title", "New Chat")
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("history", {}).get("messages", {}) or {}
def get_message_by_id_and_message_id(
self, id: str, message_id: str
) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("history", {}).get("messages", {}).get(message_id, {})
def upsert_message_to_chat_by_id_and_message_id(
self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
history["messages"][message_id] = {
**history["messages"][message_id],
**message,
}
else:
history["messages"][message_id] = message
history["currentId"] = message_id
chat["history"] = history
return self.update_chat_by_id(id, chat)
def add_message_status_to_chat_by_id_and_message_id(
self, id: str, message_id: str, status: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
status_history = history["messages"][message_id].get("statusHistory", [])
status_history.append(status)
history["messages"][message_id]["statusHistory"] = status_history
chat["history"] = history
return self.update_chat_by_id(id, chat)
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share
@@ -375,6 +469,8 @@ class ChatTable:
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
# it is possible that the shared link was deleted. hence,
# we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:

View File

@@ -27,6 +27,8 @@ class File(Base):
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@@ -44,6 +46,8 @@ class FileModel(BaseModel):
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: Optional[int] # timestamp in epoch
updated_at: Optional[int] # timestamp in epoch
@@ -90,6 +94,7 @@ class FileForm(BaseModel):
path: str
data: dict = {}
meta: dict = {}
access_control: Optional[dict] = None
class FilesTable:

View File

@@ -146,6 +146,13 @@ class GroupTable:
except Exception:
return None
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = self.get_group_by_id(id)
if group:
return group.user_ids
else:
return None
def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]:

View File

@@ -0,0 +1,279 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Message DB Schema
####################
class MessageReaction(Base):
__tablename__ = "message_reaction"
id = Column(Text, primary_key=True)
user_id = Column(Text)
message_id = Column(Text)
name = Column(Text)
created_at = Column(BigInteger)
class MessageReactionModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
message_id: str
name: str
created_at: int # timestamp in epoch
class Message(Base):
__tablename__ = "message"
id = Column(Text, primary_key=True)
user_id = Column(Text)
channel_id = Column(Text, nullable=True)
parent_id = Column(Text, nullable=True)
content = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
created_at = Column(BigInteger) # time_ns
updated_at = Column(BigInteger) # time_ns
class MessageModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
channel_id: Optional[str] = None
parent_id: Optional[str] = None
content: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class MessageForm(BaseModel):
content: str
parent_id: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
class Reactions(BaseModel):
name: str
user_ids: list[str]
count: int
class MessageResponse(MessageModel):
latest_reply_at: Optional[int]
reply_count: int
reactions: list[Reactions]
class MessageTable:
def insert_new_message(
self, form_data: MessageForm, channel_id: str, user_id: str
) -> Optional[MessageModel]:
with get_db() as db:
id = str(uuid.uuid4())
ts = int(time.time_ns())
message = MessageModel(
**{
"id": id,
"user_id": user_id,
"channel_id": channel_id,
"parent_id": form_data.parent_id,
"content": form_data.content,
"data": form_data.data,
"meta": form_data.meta,
"created_at": ts,
"updated_at": ts,
}
)
result = Message(**message.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return MessageModel.model_validate(result) if result else None
def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
with get_db() as db:
message = db.get(Message, id)
if not message:
return None
reactions = self.get_reactions_by_message_id(id)
replies = self.get_replies_by_message_id(id)
return MessageResponse(
**{
**MessageModel.model_validate(message).model_dump(),
"latest_reply_at": replies[0].created_at if replies else None,
"reply_count": len(replies),
"reactions": reactions,
}
)
def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(parent_id=id)
.order_by(Message.created_at.desc())
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
with get_db() as db:
return [
message.user_id
for message in db.query(Message).filter_by(parent_id=id).all()
]
def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(channel_id=channel_id, parent_id=None)
.order_by(Message.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
message = db.get(Message, parent_id)
if not message:
return []
all_messages = (
db.query(Message)
.filter_by(channel_id=channel_id, parent_id=parent_id)
.order_by(Message.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
# If length of all_messages is less than limit, then add the parent message
if len(all_messages) < limit:
all_messages.append(message)
return [MessageModel.model_validate(message) for message in all_messages]
def update_message_by_id(
self, id: str, form_data: MessageForm
) -> Optional[MessageModel]:
with get_db() as db:
message = db.get(Message, id)
message.content = form_data.content
message.data = form_data.data
message.meta = form_data.meta
message.updated_at = int(time.time_ns())
db.commit()
db.refresh(message)
return MessageModel.model_validate(message) if message else None
def add_reaction_to_message(
self, id: str, user_id: str, name: str
) -> Optional[MessageReactionModel]:
with get_db() as db:
reaction_id = str(uuid.uuid4())
reaction = MessageReactionModel(
id=reaction_id,
user_id=user_id,
message_id=id,
name=name,
created_at=int(time.time_ns()),
)
result = MessageReaction(**reaction.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return MessageReactionModel.model_validate(result) if result else None
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
with get_db() as db:
all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
reactions = {}
for reaction in all_reactions:
if reaction.name not in reactions:
reactions[reaction.name] = {
"name": reaction.name,
"user_ids": [],
"count": 0,
}
reactions[reaction.name]["user_ids"].append(reaction.user_id)
reactions[reaction.name]["count"] += 1
return [Reactions(**reaction) for reaction in reactions.values()]
def remove_reaction_by_id_and_user_id_and_name(
self, id: str, user_id: str, name: str
) -> bool:
with get_db() as db:
db.query(MessageReaction).filter_by(
message_id=id, user_id=user_id, name=name
).delete()
db.commit()
return True
def delete_reactions_by_id(self, id: str) -> bool:
with get_db() as db:
db.query(MessageReaction).filter_by(message_id=id).delete()
db.commit()
return True
def delete_replies_by_id(self, id: str) -> bool:
with get_db() as db:
db.query(Message).filter_by(parent_id=id).delete()
db.commit()
return True
def delete_message_by_id(self, id: str) -> bool:
with get_db() as db:
db.query(Message).filter_by(id=id).delete()
# Delete all reactions to this message
db.query(MessageReaction).filter_by(message_id=id).delete()
db.commit()
return True
Messages = MessageTable()
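A sketch of the message/thread/reaction flow these methods support (ids illustrative):

root = Messages.insert_new_message(
    MessageForm(content="Hello, channel!"), "channel-uuid", "user-a"
)
reply = Messages.insert_new_message(
    MessageForm(content="Hi back", parent_id=root.id), "channel-uuid", "user-b"
)
Messages.add_reaction_to_message(root.id, "user-b", "thumbs_up")

page = Messages.get_messages_by_channel_id("channel-uuid")  # top-level only (parent_id=None)
thread = Messages.get_replies_by_message_id(root.id)        # newest first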

View File

@@ -70,6 +70,13 @@ class UserResponse(BaseModel):
profile_image_url: str
class UserNameResponse(BaseModel):
id: str
name: str
role: str
profile_image_url: str
class UserRoleUpdateForm(BaseModel):
id: str
role: str
@@ -147,13 +154,25 @@ class UsersTable:
except Exception:
return None
-    def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]:
+    def get_users(
+        self, skip: Optional[int] = None, limit: Optional[int] = None
+    ) -> list[UserModel]:
        with get_db() as db:
-            users = (
-                db.query(User)
-                # .offset(skip).limit(limit)
-                .all()
-            )
+            query = db.query(User).order_by(User.created_at.desc())
+            if skip:
+                query = query.offset(skip)
+            if limit:
+                query = query.limit(limit)
+            users = query.all()
return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]:
@@ -168,6 +187,22 @@ class UsersTable:
except Exception:
return None
def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
try:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
if user.settings is None:
return None
else:
return (
user.settings.get("ui", {})
.get("notifications", {})
.get("webhook_url", None)
)
except Exception:
return None
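The nested settings shape this helper assumes (webhook URL illustrative):

# user.settings is expected to look like:
settings = {
    "ui": {
        "notifications": {
            "webhook_url": "https://example.com/hooks/notify"
        }
    }
}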
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try:
with get_db() as db:

View File

@@ -14,7 +14,7 @@ from langchain_core.documents import Document
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -70,7 +70,9 @@ def query_doc(
limit=k,
)
log.info(f"query_doc:result {result.ids} {result.metadatas}")
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
print(e)
@@ -373,6 +375,9 @@ def get_model_path(model: str, update_model: bool = False):
local_files_only = not update_model
if OFFLINE_MODE:
local_files_only = True
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,

View File

@@ -5,6 +5,7 @@ from sqlalchemy import (
create_engine,
Column,
Integer,
MetaData,
select,
text,
Text,
@@ -19,9 +20,9 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
-from open_webui.config import PGVECTOR_DB_URL
+from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
-VECTOR_LENGTH = 1536
+VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
@@ -56,6 +57,9 @@ class PgvectorClient:
# Ensure the pgvector extension is available
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
# Check vector length consistency
self.check_vector_length()
# Create the tables if they do not exist
# Base.metadata.create_all requires a bind (engine or connection)
# Get the connection from the session
@@ -82,6 +86,38 @@ class PgvectorClient:
print(f"Error during initialization: {e}")
raise
def check_vector_length(self) -> None:
"""
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
Raises an exception if there is a mismatch.
"""
metadata = MetaData()
metadata.reflect(bind=self.session.bind, only=["document_chunk"])
if "document_chunk" in metadata.tables:
document_chunk_table = metadata.tables["document_chunk"]
if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"]
vector_type = vector_column.type
if isinstance(vector_type, Vector):
db_vector_length = vector_type.dim
if db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
)
else:
raise Exception(
"The 'vector' column exists but is not of type 'Vector'."
)
else:
raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
else:
# Table does not exist yet; no action needed
pass
def adjust_vector_length(self, vector: List[float]) -> List[float]:
# Adjust vector to have length VECTOR_LENGTH
current_length = len(vector)

View File

@@ -683,7 +683,7 @@
"age": "October 29, 2022",
"extra_snippets": [
"You can pass many options to the configure script; run ./configure --help to find out more. On macOS case-insensitive file systems and on Cygwin, the executable is called python.exe; elsewhere it's just python.",
"Building a complete Python installation requires the use of various additional third-party libraries, depending on your build platform and configure options. Not all standard library modules are buildable or useable on all platforms. Refer to the Install dependencies section of the Developer Guide for current detailed information on dependencies for various Linux distributions and macOS.",
"Building a complete Python installation requires the use of various additional third-party libraries, depending on your build platform and configure options. Not all standard library modules are buildable or usable on all platforms. Refer to the Install dependencies section of the Developer Guide for current detailed information on dependencies for various Linux distributions and macOS.",
"To get an optimized build of Python, configure --enable-optimizations before you run make. This sets the default make targets up to enable Profile Guided Optimization (PGO) and may be used to auto-enable Link Time Optimization (LTO) on some platforms. For more details, see the sections below.",
"Copyright © 2001-2024 Python Software Foundation. All rights reserved."
]

View File

@@ -82,15 +82,15 @@ class SafeWebBaseLoader(WebBaseLoader):
def get_web_loader(
-    url: Union[str, Sequence[str]],
+    urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
requests_per_second: int = 2,
):
# Check if the URL is valid
-    if not validate_url(url):
+    if not validate_url(urls):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return SafeWebBaseLoader(
-        url,
+        urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,

View File

@@ -218,7 +218,7 @@ async def update_audio_config(
}
-def load_speech_pipeline():
+def load_speech_pipeline(request):
from transformers import pipeline
from datasets import load_dataset
@@ -236,7 +236,11 @@ def load_speech_pipeline():
@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
-    name = hashlib.sha256(body).hexdigest()
+    name = hashlib.sha256(
+        body
+        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
+        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
+    ).hexdigest()
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
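The cache key now folds the TTS engine and model into the hash, so switching engines no longer replays stale cached audio for an identical request body. Worked example (values illustrative):

import hashlib

body = b'{"input": "Hello", "voice": "alloy"}'
old_key = hashlib.sha256(body).hexdigest()  # engine-agnostic: collides across engines
new_key = hashlib.sha256(body + b"openai" + b"tts-1").hexdigest()
assert old_key != new_key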
@@ -256,10 +260,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
payload["model"] = request.app.state.config.TTS_MODEL
try:
# print(payload)
async with aiohttp.ClientSession() as session:
async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
-                data=payload,
+                json=payload,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
@@ -281,7 +286,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
-                await f.write(json.dumps(json.loads(body.decode("utf-8"))))
+                await f.write(json.dumps(payload))
return FileResponse(file_path)
@@ -292,6 +297,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
try:
if r.status != 200:
res = await r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
@@ -305,7 +311,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
voice_id = payload.get("voice", "")
-        if voice_id not in get_available_voices():
+        if voice_id not in get_available_voices(request):
raise HTTPException(
status_code=400,
detail="Invalid voice id",
@@ -332,7 +338,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
-                await f.write(json.dumps(json.loads(body.decode("utf-8"))))
+                await f.write(json.dumps(payload))
return FileResponse(file_path)
@@ -384,6 +390,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(payload))
return FileResponse(file_path)
except Exception as e:
@@ -414,7 +423,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
import torch
import soundfile as sf
-        load_speech_pipeline()
+        load_speech_pipeline(request)
embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
@@ -436,8 +445,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
)
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
-        with open(file_body_path, "w") as f:
-            json.dump(json.loads(body.decode("utf-8")), f)
+        async with aiofiles.open(file_body_path, "w") as f:
+            await f.write(json.dumps(payload))
return FileResponse(file_path)

View File

@@ -547,7 +547,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
print(form_data)
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(),
@@ -614,8 +613,12 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -625,8 +628,12 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
WEBUI_URL: str
ENABLE_SIGNUP: bool
ENABLE_API_KEY: bool
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
API_KEY_ALLOWED_ENDPOINTS: str
ENABLE_CHANNELS: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
@@ -638,8 +645,18 @@ async def update_admin_config(
request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
form_data.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
)
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS = (
form_data.API_KEY_ALLOWED_ENDPOINTS
)
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -657,8 +674,12 @@ async def update_admin_config(
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,

View File

@@ -0,0 +1,710 @@
import json
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel
from open_webui.socket.main import sio, get_user_ids_from_room
from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels, ChannelModel, ChannelForm
from open_webui.models.messages import (
Messages,
MessageModel,
MessageResponse,
MessageForm,
)
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, get_users_with_access
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# GetChannelList
############################
@router.get("/", response_model=list[ChannelModel])
async def get_channels(user=Depends(get_verified_user)):
if user.role == "admin":
return Channels.get_channels()
else:
return Channels.get_channels_by_user_id(user.id)
############################
# CreateNewChannel
############################
@router.post("/create", response_model=Optional[ChannelModel])
async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
try:
channel = Channels.insert_new_channel(None, form_data, user.id)
return ChannelModel(**channel.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChannelById
############################
@router.get("/{id}", response_model=Optional[ChannelModel])
async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
return ChannelModel(**channel.model_dump())
############################
# UpdateChannelById
############################
@router.post("/{id}/update", response_model=Optional[ChannelModel])
async def update_channel_by_id(
id: str, form_data: ChannelForm, user=Depends(get_admin_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
try:
channel = Channels.update_channel_by_id(id, form_data)
return ChannelModel(**channel.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# DeleteChannelById
############################
@router.delete("/{id}/delete", response_model=bool)
async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
try:
Channels.delete_channel_by_id(id)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChannelMessages
############################
class MessageUserResponse(MessageResponse):
user: UserNameResponse
@router.get("/{id}/messages", response_model=list[MessageUserResponse])
async def get_channel_messages(
id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message_list = Messages.get_messages_by_channel_id(id, skip, limit)
users = {}
messages = []
for message in message_list:
if message.user_id not in users:
user = Users.get_user_by_id(message.user_id)
users[message.user_id] = user
replies = Messages.get_replies_by_message_id(message.id)
latest_reply_at = replies[0].created_at if replies else None
messages.append(
MessageUserResponse(
**{
**message.model_dump(),
"reply_count": len(replies),
"latest_reply_at": latest_reply_at,
"reactions": Messages.get_reactions_by_message_id(message.id),
"user": UserNameResponse(**users[message.user_id].model_dump()),
}
)
)
return messages
############################
# PostNewMessage
############################
async def send_notification(webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control)
for user in users:
if user.id in active_user_ids:
continue
else:
if user.settings:
webhook_url = user.settings.ui.get("notifications", {}).get(
"webhook_url", None
)
if webhook_url:
post_webhook(
webhook_url,
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
{
"action": "channel",
"message": message.content,
"title": channel.name,
"url": f"{webui_url}/channels/{channel.id}",
},
)
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
async def post_new_message(
request: Request,
id: str,
form_data: MessageForm,
background_tasks: BackgroundTasks,
user=Depends(get_verified_user),
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
message = Messages.insert_new_message(form_data, channel.id, user.id)
if message:
event_data = {
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message",
"data": MessageUserResponse(
**{
**message.model_dump(),
"reply_count": 0,
"latest_reply_at": None,
"reactions": Messages.get_reactions_by_message_id(
message.id
),
"user": UserNameResponse(**user.model_dump()),
}
).model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
}
await sio.emit(
"channel-events",
event_data,
to=f"channel:{channel.id}",
)
if message.parent_id:
# If this message is a reply, emit to the parent message as well
parent_message = Messages.get_message_by_id(message.parent_id)
if parent_message:
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": parent_message.id,
"data": {
"type": "message:reply",
"data": MessageUserResponse(
**{
**parent_message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(
parent_message.user_id
).model_dump()
),
}
).model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
background_tasks.add_task(
send_notification,
request.app.state.config.WEBUI_URL,
channel,
message,
active_user_ids,
)
return MessageModel(**message.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChannelMessage
############################
@router.get("/{id}/messages/{message_id}", response_model=Optional[MessageUserResponse])
async def get_channel_message(
id: str, message_id: str, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if message.channel_id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
return MessageUserResponse(
**{
**message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump()
),
}
)
############################
# GetChannelThreadMessages
############################
@router.get(
"/{id}/messages/{message_id}/thread", response_model=list[MessageUserResponse]
)
async def get_channel_thread_messages(
id: str,
message_id: str,
skip: int = 0,
limit: int = 50,
user=Depends(get_verified_user),
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit)
users = {}
messages = []
for message in message_list:
if message.user_id not in users:
user = Users.get_user_by_id(message.user_id)
users[message.user_id] = user
messages.append(
MessageUserResponse(
**{
**message.model_dump(),
"reply_count": 0,
"latest_reply_at": None,
"reactions": Messages.get_reactions_by_message_id(message.id),
"user": UserNameResponse(**users[message.user_id].model_dump()),
}
)
)
return messages
############################
# UpdateMessageById
############################
@router.post(
"/{id}/messages/{message_id}/update", response_model=Optional[MessageModel]
)
async def update_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if message.channel_id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
try:
message = Messages.update_message_by_id(message_id, form_data)
message = Messages.get_message_by_id(message_id)
if message:
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message:update",
"data": MessageUserResponse(
**{
**message.model_dump(),
"user": UserNameResponse(
**user.model_dump()
).model_dump(),
}
).model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
return MessageModel(**message.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# AddReactionToMessage
############################
class ReactionForm(BaseModel):
name: str
@router.post("/{id}/messages/{message_id}/reactions/add", response_model=bool)
async def add_reaction_to_message(
id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if message.channel_id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
try:
Messages.add_reaction_to_message(message_id, user.id, form_data.name)
message = Messages.get_message_by_id(message_id)
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message:reaction:add",
"data": {
**message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump()
).model_dump(),
"name": form_data.name,
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# RemoveReactionById
############################
@router.post("/{id}/messages/{message_id}/reactions/remove", response_model=bool)
async def remove_reaction_by_id_and_user_id_and_name(
id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if message.channel_id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
try:
Messages.remove_reaction_by_id_and_user_id_and_name(
message_id, user.id, form_data.name
)
message = Messages.get_message_by_id(message_id)
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message:reaction:remove",
"data": {
**message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump()
).model_dump(),
"name": form_data.name,
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# DeleteMessageById
############################
@router.delete("/{id}/messages/{message_id}/delete", response_model=bool)
async def delete_message_by_id(
id: str, message_id: str, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if message.channel_id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
try:
Messages.delete_message_by_id(message_id)
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message:delete",
"data": {
**message.model_dump(),
"user": UserNameResponse(**user.model_dump()).model_dump(),
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
if message.parent_id:
# If this message is a reply, emit to the parent message as well
parent_message = Messages.get_message_by_id(message.parent_id)
if parent_message:
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": parent_message.id,
"data": {
"type": "message:reply",
"data": MessageUserResponse(
**{
**parent_message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(
parent_message.user_id
).model_dump()
),
}
).model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)

View File

@@ -463,6 +463,30 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
)
############################
# CloneSharedChatById
############################
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_share_id(id)
if chat:
updated_chat = {
**chat.chat,
"originalChatId": chat.id,
"branchPointMessageId": chat.chat["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# ArchiveChat
############################

View File

@@ -226,9 +226,16 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Handle Unicode filenames
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
headers = {
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
headers = {}
if file.meta.get("content_type") not in [
"application/pdf",
"text/plain",
]:
headers = {
**headers,
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
}
return FileResponse(file_path, headers=headers)
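
The net effect of this hunk: PDFs and plain text are now served without a Content-Disposition header, so browsers render them inline, while every other content type still downloads as an attachment. A quick verification sketch, assuming the files router is mounted at /api/v1/files (not shown in this hunk):

import requests

BASE_URL = "http://localhost:3000"  # hypothetical deployment URL
TOKEN = "sk-..."                    # placeholder bearer token

resp = requests.get(
    f"{BASE_URL}/api/v1/files/some-file-id/content",
    headers={"Authorization": f"Bearer {TOKEN}"},
)
# None for application/pdf and text/plain (rendered inline);
# an attachment header with the RFC 5987 encoded filename otherwise.
print(resp.headers.get("Content-Disposition"))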
@@ -341,7 +348,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
result = Files.delete_file_by_id(id)
if result:
try:
Storage.delete_file(file.filename)
Storage.delete_file(file.path)
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")

View File

@@ -56,6 +56,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
},
"comfyui": {
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
@@ -77,6 +78,7 @@ class Automatic1111ConfigForm(BaseModel):
class ComfyUIConfigForm(BaseModel):
COMFYUI_BASE_URL: str
COMFYUI_API_KEY: str
COMFYUI_WORKFLOW: str
COMFYUI_WORKFLOW_NODES: list[dict]
@@ -148,6 +150,7 @@ async def update_config(
},
"comfyui": {
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
@@ -197,7 +200,7 @@ def set_image_model(request: Request, model: str):
log.info(f"Setting image model to {model}")
request.app.state.config.IMAGE_GENERATION_MODEL = model
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
api_auth = get_automatic1111_api_auth()
api_auth = get_automatic1111_api_auth(request)
r = requests.get(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": api_auth},
@@ -233,7 +236,7 @@ def get_image_model(request):
try:
r = requests.get(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()},
headers={"authorization": get_automatic1111_api_auth(request)},
)
options = r.json()
return options["sd_model_checkpoint"]
@@ -298,8 +301,12 @@ def get_models(request: Request, user=Depends(get_verified_user)):
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
headers = {
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
headers=headers,
)
info = r.json()
@@ -347,7 +354,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
):
r = requests.get(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
headers={"authorization": get_automatic1111_api_auth()},
headers={"authorization": get_automatic1111_api_auth(request)},
)
models = r.json()
return list(
@@ -521,6 +528,7 @@ async def image_generations(
form_data,
user.id,
request.app.state.config.COMFYUI_BASE_URL,
request.app.state.config.COMFYUI_API_KEY,
)
log.debug(f"res: {res}")
@@ -570,7 +578,7 @@ async def image_generations(
requests.post,
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
headers={"authorization": get_automatic1111_api_auth()},
headers={"authorization": get_automatic1111_api_auth(request)},
)
res = r.json()

View File

@@ -1,5 +1,4 @@
import json
from typing import Optional, Union
from typing import List, Optional
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Request
import logging
@@ -12,11 +11,16 @@ from open_webui.models.knowledge import (
)
from open_webui.models.files import Files, FileModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.routers.retrieval import (
process_file,
ProcessFileForm,
process_files_batch,
BatchProcessFilesForm,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.auth import get_verified_user
from open_webui.utils.access_control import has_access, has_permission
@@ -415,13 +419,6 @@ def remove_file_from_knowledge_by_id(
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
)
result = VECTOR_DB_CLIENT.query(
collection_name=knowledge.id,
filter={"file_id": form_data.file_id},
)
Files.delete_file_by_id(form_data.file_id)
if knowledge:
data = knowledge.data or {}
file_ids = data.get("file_ids", [])
@@ -514,3 +511,86 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
return knowledge
############################
# AddFilesToKnowledge
############################
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
def add_files_to_knowledge_batch(
request: Request,
id: str,
form_data: list[KnowledgeFileIdForm],
user=Depends(get_verified_user),
):
"""
Add multiple files to a knowledge base
"""
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
# Get files content
print(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)
if not file:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File {form.file_id} not found",
)
files.append(file)
# Process files
try:
result = process_files_batch(
request=request,
form_data=BatchProcessFilesForm(files=files, collection_name=id),
user=user,
)
except Exception as e:
log.error(
f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Add successful files to knowledge base
data = knowledge.data or {}
existing_file_ids = data.get("file_ids", [])
# Only add files that were successfully processed
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
for file_id in successful_file_ids:
if file_id not in existing_file_ids:
existing_file_ids.append(file_id)
data["file_ids"] = existing_file_ids
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
# If there were any errors, include them in the response
if result.errors:
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Files.get_files_by_ids(existing_file_ids),
warnings={
"message": "Some files failed to process",
"errors": error_details,
},
)
return KnowledgeFilesResponse(
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
)
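
A hedged usage sketch for the new batch endpoint; the /api/v1/knowledge prefix and all IDs below are assumptions for illustration. The request body is a JSON array of {"file_id": ...} objects, and per-file failures come back in a warnings field rather than failing the whole batch:

import requests

BASE_URL = "http://localhost:3000"  # hypothetical deployment URL
TOKEN = "sk-..."                    # placeholder bearer token

resp = requests.post(
    f"{BASE_URL}/api/v1/knowledge/kb-123/files/batch/add",
    json=[{"file_id": "file-1"}, {"file_id": "file-2"}],
    headers={"Authorization": f"Bearer {TOKEN}"},
)
data = resp.json()
# Successfully processed files are attached to the knowledge base;
# failures are listed per file in "warnings".
print(data.get("warnings", "all files processed"))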

View File

@@ -82,6 +82,16 @@ async def send_get_request(url, key=None):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
async def send_post_request(
url: str,
payload: Union[str, bytes],
@@ -89,14 +99,6 @@ async def send_post_request(
key: Optional[str] = None,
content_type: Optional[str] = None,
):
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
r = None
try:
@@ -917,7 +919,7 @@ class ChatMessage(BaseModel):
class GenerateChatCompletionForm(BaseModel):
model: str
messages: list[ChatMessage]
format: Optional[str] = None
format: Optional[dict] = None
options: Optional[dict] = None
template: Optional[str] = None
stream: Optional[bool] = True
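
Widening `format` from str to dict matches Ollama's structured outputs, where a JSON schema constrains the model's reply; note that a bare "json" string no longer validates against Optional[dict]. A minimal sketch, assuming ChatMessage exposes role/content fields and using a hypothetical model name:

form = GenerateChatCompletionForm(
    model="llama3.2",
    messages=[{"role": "user", "content": "List three colors as JSON."}],
    format={
        "type": "object",
        "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
        "required": ["colors"],
    },
    stream=False,
)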

View File

@@ -533,6 +533,9 @@ async def generate_chat_completion(
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
):
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
idx = 0
payload = {**form_data}
if "metadata" in payload:
@@ -545,6 +548,7 @@ async def generate_chat_completion(
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_id = model_info.base_model_id
params = model_info.params.model_dump()
payload = apply_model_params_to_body_openai(params, payload)
@@ -604,14 +608,13 @@ async def generate_chat_completion(
if is_o1:
payload = openai_o1_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_tokens" from the payload for backward compatibility
if "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
payload["max_tokens"] = payload["max_completion_tokens"]
del payload["max_completion_tokens"]
# TODO: check if below is needed
# if "max_tokens" in payload and "max_completion_tokens" in payload:
# del payload["max_tokens"]
if "max_tokens" in payload and "max_completion_tokens" in payload:
del payload["max_tokens"]
# Convert the modified body back to JSON
payload = json.dumps(payload)

View File

@@ -124,18 +124,14 @@ def process_pipeline_outlet_filter(request, payload, user, models):
f"{url}/{filter['id']}/filter/outlet",
headers={"Authorization": f"Bearer {key}"},
json={
"user": {
"id": user.id,
"name": user.name,
"email": user.email,
"role": user.role,
},
"body": data,
"user": user,
"body": payload,
},
)
r.raise_for_status()
data = r.json()
payload = data
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")

View File

@@ -7,7 +7,7 @@ import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
from typing import Iterator, List, Optional, Sequence, Union
from fastapi import (
Depends,
@@ -28,7 +28,7 @@ import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_core.documents import Document
from open_webui.models.files import Files
from open_webui.models.files import FileModel, Files
from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage
@@ -347,6 +347,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
@@ -369,6 +370,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
@@ -445,6 +447,7 @@ class WebConfig(BaseModel):
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
enable_google_drive_integration: Optional[bool] = None
file: Optional[FileConfig] = None
content_extraction: Optional[ContentExtractionConfig] = None
chunk: Optional[ChunkParamUpdateForm] = None
@@ -462,6 +465,12 @@ async def update_rag_config(
else request.app.state.config.PDF_EXTRACT_IMAGES
)
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
form_data.enable_google_drive_integration
if form_data.enable_google_drive_integration is not None
else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
)
if form_data.file is not None:
request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
@@ -1247,21 +1256,22 @@ def process_web_search(
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
)
log.debug(f"web_results: {web_results}")
try:
collection_name = form_data.collection_name
if collection_name == "":
if collection_name == "" or collection_name is None:
collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
:63
]
urls = [result.link for result in web_results]
loader = get_web_loader(
urls=urls,
urls,
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
)
docs = loader.aload()
docs = loader.load()
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
return {
@@ -1428,3 +1438,94 @@ if ENV == "dev":
@router.get("/ef/{text}")
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
class BatchProcessFilesForm(BaseModel):
files: List[FileModel]
collection_name: str
class BatchProcessFilesResult(BaseModel):
file_id: str
status: str
error: Optional[str] = None
class BatchProcessFilesResponse(BaseModel):
results: List[BatchProcessFilesResult]
errors: List[BatchProcessFilesResult]
@router.post("/process/files/batch")
def process_files_batch(
request: Request,
form_data: BatchProcessFilesForm,
user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
"""
Process a batch of files and save them to the vector database.
"""
results: List[BatchProcessFilesResult] = []
errors: List[BatchProcessFilesResult] = []
collection_name = form_data.collection_name
# Prepare all documents first
all_docs: List[Document] = []
for file in form_data.files:
try:
text_content = file.data.get("content", "")
docs: List[Document] = [
Document(
page_content=text_content.replace("<br/>", "\n"),
metadata={
**file.meta,
"name": file.filename,
"created_by": file.user_id,
"file_id": file.id,
"source": file.filename,
},
)
]
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
Files.update_file_data_by_id(file.id, {"content": text_content})
all_docs.extend(docs)
results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
except Exception as e:
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
errors.append(
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
)
# Save all documents in one batch
if all_docs:
try:
save_docs_to_vector_db(
request=request,
docs=all_docs,
collection_name=collection_name,
add=True,
)
# Update all files with collection name
for result in results:
Files.update_file_metadata_by_id(
result.file_id, {"collection_name": collection_name}
)
result.status = "completed"
except Exception as e:
log.error(
f"process_files_batch: Error saving documents to vector DB: {str(e)}"
)
for result in results:
result.status = "failed"
errors.append(
BatchProcessFilesResult(file_id=result.file_id, error=str(e))
)
return BatchProcessFilesResponse(results=results, errors=errors)

View File

@@ -186,9 +186,10 @@ async def generate_title(
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error("Exception occurred", exc_info=True)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
content={"detail": "An internal error has occurred."},
)
@@ -248,9 +249,10 @@ async def generate_chat_tags(
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error(f"Error generating chat completion: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "An internal error has occurred."},
)
@@ -393,9 +395,10 @@ async def generate_autocompletion(
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error(f"Error generating chat completion: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "An internal error has occurred."},
)
@@ -496,8 +499,8 @@ async def generate_moa_response(
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"chat_id": form_data.get("chat_id", None),
"metadata": {
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.MOA_RESPONSE_GENERATION),
"task_body": form_data,
},

View File

@@ -10,6 +10,9 @@ from open_webui.models.users import (
UserSettings,
UserUpdateForm,
)
from open_webui.socket.main import get_active_status_by_user_id
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -27,7 +30,11 @@ router = APIRouter()
@router.get("/", response_model=list[UserModel])
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
async def get_users(
skip: Optional[int] = None,
limit: Optional[int] = None,
user=Depends(get_admin_user),
):
return Users.get_users(skip, limit)
@@ -192,6 +199,7 @@ async def update_user_info_by_session_user(
class UserResponse(BaseModel):
name: str
profile_image_url: str
active: Optional[bool] = None
@router.get("/{user_id}", response_model=UserResponse)
@@ -212,7 +220,13 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
user = Users.get_user_by_id(user_id)
if user:
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
return UserResponse(
**{
"name": user.name,
"profile_image_url": user.profile_image_url,
"active": get_active_status_by_user_id(user_id),
}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@@ -4,14 +4,17 @@ import logging
import sys
import time
from open_webui.models.users import Users
from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER,
WEBSOCKET_REDIS_URL,
)
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict
from open_webui.socket.utils import RedisDict, RedisLock
from open_webui.env import (
GLOBAL_LOG_LEVEL,
@@ -29,9 +32,7 @@ if WEBSOCKET_MANAGER == "redis":
sio = socketio.AsyncServer(
cors_allowed_origins=[],
async_mode="asgi",
transports=(
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
),
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True,
client_manager=mgr,
@@ -40,54 +41,77 @@ else:
sio = socketio.AsyncServer(
cors_allowed_origins=[],
async_mode="asgi",
transports=(
["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
),
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True,
)
# Timeout duration in seconds
TIMEOUT_DURATION = 3
# Dictionary to maintain the user pool
if WEBSOCKET_MANAGER == "redis":
log.debug("Using Redis to manage websockets.")
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock",
timeout_secs=TIMEOUT_DURATION * 2,
)
aquire_func = clean_up_lock.aquire_lock
renew_func = clean_up_lock.renew_lock
release_func = clean_up_lock.release_lock
else:
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Timeout duration in seconds
TIMEOUT_DURATION = 3
aquire_func = release_func = renew_func = lambda: True
async def periodic_usage_pool_cleanup():
while True:
now = int(time.time())
for model_id, connections in list(USAGE_POOL.items()):
# Creating a list of sids to remove if they have timed out
expired_sids = [
sid
for sid, details in connections.items()
if now - details["updated_at"] > TIMEOUT_DURATION
]
if not aquire_func():
log.debug("Usage pool cleanup lock already exists. Not running it.")
return
log.debug("Running periodic_usage_pool_cleanup")
try:
while True:
if not renew_func():
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
raise Exception("Unable to renew usage pool cleanup lock.")
for sid in expired_sids:
del connections[sid]
now = int(time.time())
send_usage = False
for model_id, connections in list(USAGE_POOL.items()):
# Creating a list of sids to remove if they have timed out
expired_sids = [
sid
for sid, details in connections.items()
if now - details["updated_at"] > TIMEOUT_DURATION
]
if not connections:
log.debug(f"Cleaning up model {model_id} from usage pool")
del USAGE_POOL[model_id]
else:
USAGE_POOL[model_id] = connections
for sid in expired_sids:
del connections[sid]
# Emit updated usage information after cleaning
await sio.emit("usage", {"models": get_models_in_use()})
if not connections:
log.debug(f"Cleaning up model {model_id} from usage pool")
del USAGE_POOL[model_id]
else:
USAGE_POOL[model_id] = connections
await asyncio.sleep(TIMEOUT_DURATION)
send_usage = True
if send_usage:
# Emit updated usage information after cleaning
await sio.emit("usage", {"models": get_models_in_use()})
await asyncio.sleep(TIMEOUT_DURATION)
finally:
release_func()
app = socketio.ASGIApp(
@@ -128,20 +152,19 @@ async def connect(sid, environ, auth):
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
SESSION_POOL[sid] = user.model_dump()
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
else:
USER_POOL[user.id] = [sid]
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(USER_POOL.items())})
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
await sio.emit("usage", {"models": get_models_in_use()})
@sio.on("user-join")
async def user_join(sid, data):
# print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if not auth or "token" not in auth:
@@ -155,39 +178,91 @@ async def user_join(sid, data):
if not user:
return
SESSION_POOL[sid] = user.id
SESSION_POOL[sid] = user.model_dump()
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
else:
USER_POOL[user.id] = [sid]
# Join all the channels
channels = Channels.get_channels_by_user_id(user.id)
log.debug(f"{channels=}")
for channel in channels:
await sio.enter_room(sid, f"channel:{channel.id}")
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(USER_POOL.items())})
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
return {"id": user.id, "name": user.name}
@sio.on("user-count")
async def user_count(sid):
await sio.emit("user-count", {"count": len(USER_POOL.items())})
@sio.on("join-channels")
async def join_channel(sid, data):
auth = data["auth"] if "auth" in data else None
if not auth or "token" not in auth:
return
data = decode_token(auth["token"])
if data is None or "id" not in data:
return
user = Users.get_user_by_id(data["id"])
if not user:
return
# Join all the channels
channels = Channels.get_channels_by_user_id(user.id)
log.debug(f"{channels=}")
for channel in channels:
await sio.enter_room(sid, f"channel:{channel.id}")
@sio.on("chat")
async def chat(sid, data):
print("chat", sid, SESSION_POOL[sid], data)
@sio.on("channel-events")
async def channel_events(sid, data):
room = f"channel:{data['channel_id']}"
participants = sio.manager.get_participants(
namespace="/",
room=room,
)
sids = [sid for sid, _ in participants]
if sid not in sids:
return
event_data = data["data"]
event_type = event_data["type"]
if event_type == "typing":
await sio.emit(
"channel-events",
{
"channel_id": data["channel_id"],
"message_id": data.get("message_id", None),
"data": event_data,
"user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
},
room=room,
)
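
From the client side, the relay above expects an event shaped like the payload below. This is a sketch using the python-socketio client; the connection URL, the socketio_path, and the extra "typing" flag are assumptions:

import socketio

sio_client = socketio.AsyncClient()

async def send_typing(channel_id: str):
    # The server only relays events whose sender has joined the channel room.
    await sio_client.connect(
        "http://localhost:3000", socketio_path="/ws/socket.io"
    )
    await sio_client.emit(
        "channel-events",
        {
            "channel_id": channel_id,
            "data": {"type": "typing", "typing": True},
        },
    )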
@sio.on("user-list")
async def user_list(sid):
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
user_id = SESSION_POOL[sid]
user = SESSION_POOL[sid]
del SESSION_POOL[sid]
user_id = user["id"]
USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
await sio.emit("user-count", {"count": len(USER_POOL)})
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
else:
pass
# print(f"Unknown session ID {sid} disconnected")
@@ -195,16 +270,57 @@ async def disconnect(sid):
def get_event_emitter(request_info):
async def __event_emitter__(event_data):
await sio.emit(
"chat-events",
{
"chat_id": request_info["chat_id"],
"message_id": request_info["message_id"],
"data": event_data,
},
to=request_info["session_id"],
user_id = request_info["user_id"]
session_ids = list(
set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
)
for session_id in session_ids:
await sio.emit(
"chat-events",
{
"chat_id": request_info["chat_id"],
"message_id": request_info["message_id"],
"data": event_data,
},
to=session_id,
)
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
return __event_emitter__
@@ -222,3 +338,30 @@ def get_event_call(request_info):
return response
return __event_call__
def get_user_id_from_session_pool(sid):
user = SESSION_POOL.get(sid)
if user:
return user["id"]
return None
def get_user_ids_from_room(room):
active_session_ids = sio.manager.get_participants(
namespace="/",
room=room,
)
active_user_ids = list(
set(
[SESSION_POOL.get(session_id[0])["id"] for session_id in active_session_ids]
)
)
return active_user_ids
def get_active_status_by_user_id(user_id):
if user_id in USER_POOL:
return True
return False

View File

@@ -1,5 +1,33 @@
import json
import redis
import uuid
class RedisLock:
def __init__(self, redis_url, lock_name, timeout_secs):
self.lock_name = lock_name
self.lock_id = str(uuid.uuid4())
self.timeout_secs = timeout_secs
self.lock_obtained = False
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
def aquire_lock(self):
# nx=True will only set this key if it _hasn't_ already been set
self.lock_obtained = self.redis.set(
self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
)
return self.lock_obtained
def renew_lock(self):
# xx=True will only set this key if it _has_ already been set
return self.redis.set(
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
)
def release_lock(self):
lock_value = self.redis.get(self.lock_name)
# decode_responses=True means get() already returns a str, not bytes
if lock_value == self.lock_id:
self.redis.delete(self.lock_name)
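
A minimal usage sketch of the lock (the Redis URL and sleep interval are placeholders); this mirrors how periodic_usage_pool_cleanup uses it so that only one worker runs the cleanup loop at a time:

import time

from open_webui.socket.utils import RedisLock

lock = RedisLock(
    redis_url="redis://localhost:6379/0",  # placeholder Redis URL
    lock_name="open-webui:example_lock",
    timeout_secs=6,
)
if lock.aquire_lock():  # note: spelled "aquire" in this codebase
    try:
        while True:
            ...  # periodic work goes here
            time.sleep(3)
            if not lock.renew_lock():
                break  # lock expired or was taken over; stop working
    finally:
        lock.release_lock()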
class RedisDict:

View File

@@ -26,8 +26,8 @@
html {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR',
'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto,
'Helvetica Neue', Arial, sans-serif;
'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium',
Roboto, 'Helvetica Neue', Arial, sans-serif;
font-size: 14px; /* Default font size */
line-height: 1.5;
}

Binary file not shown (new image, 4.9 KiB).

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@@ -147,8 +147,10 @@ class StorageProvider:
return self._get_file_from_s3(file_path)
return self._get_file_from_local(file_path)
def delete_file(self, filename: str) -> None:
def delete_file(self, file_path: str) -> None:
"""Deletes a file either from S3 or the local file system."""
filename = file_path.split("/")[-1]
if self.storage_provider == "s3":
self._delete_from_s3(filename)

View File

@@ -0,0 +1,61 @@
# tasks.py
import asyncio
from typing import Dict
from uuid import uuid4
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
def cleanup_task(task_id: str):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
tasks.pop(task_id, None) # Remove the task if it exists
def create_task(coroutine):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
task_id = str(uuid4()) # Generate a unique ID for the task
task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup
task.add_done_callback(lambda t: cleanup_task(task_id))
tasks[task_id] = task
return task_id, task
def get_task(task_id: str):
"""
Retrieve a task by its task ID.
"""
return tasks.get(task_id)
def list_tasks():
"""
List all currently active task IDs.
"""
return list(tasks.keys())
async def stop_task(task_id: str):
"""
Cancel a running task and remove it from the global task list.
"""
task = tasks.get(task_id)
if not task:
raise ValueError(f"Task with ID {task_id} not found.")
task.cancel() # Request task cancellation
try:
await task # Wait for the task to handle the cancellation
except asyncio.CancelledError:
# Task successfully canceled
tasks.pop(task_id, None) # Remove it from the dictionary
return {"status": True, "message": f"Task {task_id} successfully stopped."}
return {"status": False, "message": f"Failed to stop task {task_id}."}

View File

@@ -1,4 +1,5 @@
from typing import Optional, Union, List, Dict, Any
from open_webui.models.users import Users, UserModel
from open_webui.models.groups import Groups
import json
@@ -93,3 +94,24 @@ def has_access(
return user_id in permitted_user_ids or any(
group_id in permitted_group_ids for group_id in user_group_ids
)
# Get all users with access to a resource
def get_users_with_access(
type: str = "write", access_control: Optional[dict] = None
) -> List[UserModel]:
if access_control is None:
return Users.get_users()
permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", [])
permitted_user_ids = permission_access.get("user_ids", [])
user_ids_with_access = set(permitted_user_ids)
for group_id in permitted_group_ids:
group_user_ids = Groups.get_group_user_ids_by_id(group_id)
if group_user_ids:
user_ids_with_access.update(group_user_ids)
return Users.get_users_by_user_ids(list(user_ids_with_access))
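
The access_control shape this helper reads (per has_access above) is a per-permission map of group and user ID lists; a sketch with placeholder IDs:

from open_webui.utils.access_control import get_users_with_access

# Everyone in group "team-a" plus user "u1" can read; only "u1" can write.
access_control = {
    "read": {"group_ids": ["team-a"], "user_ids": ["u1"]},
    "write": {"group_ids": [], "user_ids": ["u1"]},
}

readers = get_users_with_access(type="read", access_control=access_control)
writers = get_users_with_access(type="write", access_control=access_control)
# access_control=None means the resource is public: all users are returned.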

View File

@@ -95,6 +95,20 @@ def get_current_user(
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
allowed_paths = [
path.strip()
for path in str(
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
).split(",")
]
if request.url.path not in allowed_paths:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
return get_current_user_by_api_key(token)
# auth by jwt token
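
One behavioral note on the new restriction: the allow-list is a comma-separated value and the check is an exact string match on request.url.path, so prefixes, wildcards, and trailing-slash variants are rejected. A sketch with an example value (the paths themselves are illustrative):

raw = "/api/models,/api/v1/chats/list"  # example API_KEY_ALLOWED_ENDPOINTS value
allowed_paths = [path.strip() for path in raw.split(",")]

assert "/api/models" in allowed_paths
assert "/api/models/" not in allowed_paths  # trailing slash would be 403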

View File

@@ -89,7 +89,7 @@ async def generate_chat_completion(
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in await get_all_models(request)
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
@@ -99,7 +99,7 @@ async def generate_chat_completion(
else:
model_ids = [
model["id"]
for model in await get_all_models(request)
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
@@ -114,21 +114,27 @@ async def generate_chat_completion(
yield chunk
response = await generate_chat_completion(
form_data, user, bypass_filter=True
request, form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
**(await generate_chat_completion(form_data, user, bypass_filter=True)),
**(
await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route that uses this function, and it already bypasses the filter
return await generate_function_chat_completion(
form_data, user=user, models=models
request, form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
@@ -141,6 +147,7 @@ async def generate_chat_completion(
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
return convert_response_ollama_to_openai(response)
@@ -150,8 +157,12 @@ async def generate_chat_completion(
)
chat_completion = generate_chat_completion
async def chat_completed(request: Request, form_data: dict, user: Any):
await get_all_models(request)
if not request.app.state.MODELS:
await get_all_models(request)
models = request.app.state.MODELS
data = form_data
@@ -171,6 +182,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)
@@ -179,6 +191,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)
@@ -286,7 +299,8 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
if not action:
raise Exception(f"Action not found: {action_id}")
await get_all_models(request)
if not request.app.state.MODELS:
await get_all_models(request)
models = request.app.state.MODELS
data = form_data
@@ -301,6 +315,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)
__event_call__ = get_event_call(
@@ -308,6 +323,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)

View File

@@ -16,14 +16,16 @@ log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
default_headers = {"User-Agent": "Mozilla/5.0"}
def queue_prompt(prompt, client_id, base_url):
def queue_prompt(prompt, client_id, base_url, api_key):
log.info("queue_prompt")
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
log.debug(f"queue_prompt data: {data}")
try:
req = urllib.request.Request(
f"{base_url}/prompt", data=data, headers=default_headers
f"{base_url}/prompt",
data=data,
headers={**default_headers, "Authorization": f"Bearer {api_key}"},
)
response = urllib.request.urlopen(req).read()
return json.loads(response)
@@ -32,12 +34,13 @@ def queue_prompt(prompt, client_id, base_url):
raise e
def get_image(filename, subfolder, folder_type, base_url):
def get_image(filename, subfolder, folder_type, base_url, api_key):
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
req = urllib.request.Request(
f"{base_url}/view?{url_values}", headers=default_headers
f"{base_url}/view?{url_values}",
headers={**default_headers, "Authorization": f"Bearer {api_key}"},
)
with urllib.request.urlopen(req) as response:
return response.read()
@@ -50,18 +53,19 @@ def get_image_url(filename, subfolder, folder_type, base_url):
return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url):
def get_history(prompt_id, base_url, api_key):
log.info("get_history")
req = urllib.request.Request(
f"{base_url}/history/{prompt_id}", headers=default_headers
f"{base_url}/history/{prompt_id}",
headers={**default_headers, "Authorization": f"Bearer {api_key}"},
)
with urllib.request.urlopen(req) as response:
return json.loads(response.read())
def get_images(ws, prompt, client_id, base_url):
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
def get_images(ws, prompt, client_id, base_url, api_key):
prompt_id = queue_prompt(prompt, client_id, base_url, api_key)["prompt_id"]
output_images = []
while True:
out = ws.recv()
@@ -74,7 +78,7 @@ def get_images(ws, prompt, client_id, base_url):
else:
continue # previews are binary data
history = get_history(prompt_id, base_url)[prompt_id]
history = get_history(prompt_id, base_url, api_key)[prompt_id]
for o in history["outputs"]:
for node_id in history["outputs"]:
node_output = history["outputs"][node_id]
@@ -113,7 +117,7 @@ class ComfyUIGenerateImageForm(BaseModel):
async def comfyui_generate_image(
model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
model: str, payload: ComfyUIGenerateImageForm, client_id, base_url, api_key
):
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
workflow = json.loads(payload.workflow.workflow)
@@ -167,7 +171,8 @@ async def comfyui_generate_image(
try:
ws = websocket.WebSocket()
ws.connect(f"{ws_url}/ws?clientId={client_id}")
headers = {"Authorization": f"Bearer {api_key}"}
ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers)
log.info("WebSocket connection established.")
except Exception as e:
log.exception(f"Failed to connect to WebSocket server: {e}")
@@ -176,7 +181,9 @@ async def comfyui_generate_image(
try:
log.info("Sending workflow to WebSocket server.")
log.info(f"Workflow: {workflow}")
images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
images = await asyncio.to_thread(
get_images, ws, workflow, client_id, base_url, api_key
)
except Exception as e:
log.exception(f"Error while receiving images: {e}")
images = None

View File

@@ -2,21 +2,36 @@ import time
import logging
import sys
import asyncio
from aiocache import cached
from typing import Any, Optional
import random
import json
import inspect
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
from fastapi import Request
from fastapi import BackgroundTasks
from starlette.responses import Response, StreamingResponse
from open_webui.models.chats import Chats
from open_webui.models.users import Users
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
get_active_status_by_user_id,
)
from open_webui.routers.tasks import generate_queries
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.utils.webhook import post_webhook
from open_webui.models.users import UserModel
@@ -33,16 +48,25 @@ from open_webui.utils.task import (
tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
get_message_list,
add_or_update_system_message,
get_last_user_message,
get_last_assistant_message,
prepend_to_first_user_message_content,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.tasks import create_task
from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
from open_webui.env import (
SRC_LOG_LEVELS,
GLOBAL_LOG_LEVEL,
BYPASS_MODEL_ACCESS_CONTROL,
ENABLE_REALTIME_CHAT_SAVE,
)
from open_webui.constants import TASKS
@@ -312,6 +336,156 @@ async def chat_completion_tools_handler(
return body, {"sources": sources}
async def chat_web_search_handler(
request: Request, form_data: dict, extra_params: dict, user
):
event_emitter = extra_params["__event_emitter__"]
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "Generating search query",
"done": False,
},
}
)
messages = form_data["messages"]
user_message = get_last_user_message(messages)
queries = []
try:
res = await generate_queries(
request,
{
"model": form_data["model"],
"messages": messages,
"prompt": user_message,
"type": "web_search",
},
user,
)
response = res["choices"][0]["message"]["content"]
try:
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1
# rfind() returns -1 when "}" is missing, so bracket_end is 0 in that case
if bracket_start == -1 or bracket_end == 0:
raise Exception("No JSON object found in the response")
response = response[bracket_start:bracket_end]
queries = json.loads(response)
queries = queries.get("queries", [])
except Exception as e:
queries = [response]
except Exception as e:
log.exception(e)
queries = [user_message]
if len(queries) == 0:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search query generated",
"done": True,
},
}
)
return
searchQuery = queries[0]
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": 'Searching "{{searchQuery}}"',
"query": searchQuery,
"done": False,
},
}
)
try:
# Offload process_web_search to a separate thread
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
results = await loop.run_in_executor(
executor,
lambda: process_web_search(
request,
SearchForm(
**{
"query": searchQuery,
}
),
user,
),
)
if results:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "Searched {{count}} sites",
"query": searchQuery,
"urls": results["filenames"],
"done": True,
},
}
)
files = form_data.get("files", [])
files.append(
{
"collection_name": results["collection_name"],
"name": searchQuery,
"type": "web_search_results",
"urls": results["filenames"],
}
)
form_data["files"] = files
else:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search results found",
"query": searchQuery,
"done": True,
"error": True,
},
}
)
except Exception as e:
log.exception(e)
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": 'Error searching "{{searchQuery}}"',
"query": searchQuery,
"done": True,
"error": True,
},
}
)
return form_data
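
The handler above assumes generate_queries returns a JSON object with a "queries" list, possibly wrapped in extra prose that the bracket slicing strips; a sketch of that contract:

import json

raw = 'Sure! {"queries": ["weather in Paris", "Paris forecast tomorrow"]}'
start, end = raw.find("{"), raw.rfind("}") + 1
queries = json.loads(raw[start:end]).get("queries", [])
assert queries == ["weather in Paris", "Paris forecast tomorrow"]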
async def chat_completion_files_handler(
request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
@@ -320,6 +494,7 @@ async def chat_completion_files_handler(
if files := body.get("metadata", {}).get("files", None):
try:
queries_response = await generate_queries(
request,
{
"model": body["model"],
"messages": body["messages"],
@@ -362,19 +537,44 @@ async def chat_completion_files_handler(
return body, {"sources": sources}
async def process_chat_payload(request, form_data, user, model):
metadata = {
"chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None),
"tool_ids": form_data.get("tool_ids", None),
"files": form_data.get("files", None),
}
form_data["metadata"] = metadata
def apply_params_to_form_data(form_data, model):
params = form_data.pop("params", {})
if model.get("ollama"):
form_data["options"] = params
if "format" in params:
form_data["format"] = params["format"]
if "keep_alive" in params:
form_data["keep_alive"] = params["keep_alive"]
else:
if "seed" in params:
form_data["seed"] = params["seed"]
if "stop" in params:
form_data["stop"] = params["stop"]
if "temperature" in params:
form_data["temperature"] = params["temperature"]
if "top_p" in params:
form_data["top_p"] = params["top_p"]
if "frequency_penalty" in params:
form_data["frequency_penalty"] = params["frequency_penalty"]
return form_data
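
A sketch of how the same saved params fan out per backend (the model dicts are minimal placeholders): Ollama-backed models receive everything under "options" plus top-level format/keep_alive, while other backends only get the whitelisted sampling keys copied to the top level:

form_data = {
    "model": "llama3.2",  # placeholder
    "messages": [],
    "params": {"temperature": 0.2, "seed": 42, "keep_alive": "5m"},
}

ollama_body = apply_params_to_form_data(dict(form_data), {"ollama": True})
# -> options={"temperature": 0.2, "seed": 42, ...} plus keep_alive="5m"

openai_body = apply_params_to_form_data(dict(form_data), {})
# -> seed/temperature copied to the top level; keep_alive is dropped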
async def process_chat_payload(request, form_data, metadata, user, model):
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
event_emitter = get_event_emitter(metadata)
event_call = get_event_call(metadata)
extra_params = {
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
"__event_emitter__": event_emitter,
"__event_call__": event_call,
"__user__": {
"id": user.id,
"email": user.email,
@@ -388,18 +588,70 @@ async def process_chat_payload(request, form_data, user, model):
# Initialize events to store additional event to be sent to the client
# Initialize contexts and citation
models = request.app.state.MODELS
events = []
sources = []
user_message = get_last_user_message(form_data["messages"])
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
if model_knowledge:
await event_emitter(
{
"type": "status",
"data": {
"action": "knowledge_search",
"query": user_message,
"done": False,
},
}
)
knowledge_files = []
for item in model_knowledge:
if item.get("collection_name"):
knowledge_files.append(
{
"id": item.get("collection_name"),
"name": item.get("name"),
"legacy": True,
}
)
elif item.get("collection_names"):
knowledge_files.append(
{
"name": item.get("name"),
"type": "collection",
"collection_names": item.get("collection_names"),
"legacy": True,
}
)
else:
knowledge_files.append(item)
files = form_data.get("files", [])
files.extend(knowledge_files)
form_data["files"] = files
features = form_data.pop("features", None)
if features:
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
)
try:
form_data, flags = await chat_completion_filter_functions_handler(
request, form_data, model, extra_params
)
except Exception as e:
return Exception(f"Error: {e}")
raise Exception(f"Error: {e}")
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# Remove files duplicates
if files:
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
metadata = {
**metadata,
@@ -478,31 +730,366 @@ async def process_chat_payload(request, form_data, user, model):
if len(sources) > 0:
events.append({"sources": sources})
if model_knowledge:
await event_emitter(
{
"type": "status",
"data": {
"action": "knowledge_search",
"query": user_message,
"done": True,
"hidden": True,
},
}
)
return form_data, events
async def process_chat_response(response, events):
async def process_chat_response(
request, response, form_data, user, events, metadata, tasks
):
async def background_tasks_handler():
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
message = message_map.get(metadata["message_id"]) if message_map else None
if message:
messages = get_message_list(message_map, message.get("id"))
if tasks:
if TASKS.TITLE_GENERATION in tasks:
if tasks[TASKS.TITLE_GENERATION]:
res = await generate_title(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res and isinstance(res, dict):
title = (
res.get("choices", [])[0]
.get("message", {})
.get(
"content",
message.get("content", "New Chat"),
)
).strip()
if not title:
title = messages[0].get("content", "New Chat")
Chats.update_chat_title_by_id(metadata["chat_id"], title)
await event_emitter(
{
"type": "chat:title",
"data": title,
}
)
elif len(messages) == 2:
title = messages[0].get("content", "New Chat")
Chats.update_chat_title_by_id(metadata["chat_id"], title)
await event_emitter(
{
"type": "chat:title",
"data": message.get("content", "New Chat"),
}
)
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
res = await generate_chat_tags(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res and isinstance(res, dict):
tags_string = (
res.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
tags_string = tags_string[
tags_string.find("{") : tags_string.rfind("}") + 1
]
try:
tags = json.loads(tags_string).get("tags", [])
Chats.update_chat_tags_by_id(
metadata["chat_id"], tags, user
)
await event_emitter(
{
"type": "chat:tags",
"data": tags,
}
)
except Exception as e:
print(f"Error: {e}")
event_emitter = None
if (
"session_id" in metadata
and metadata["session_id"]
and "chat_id" in metadata
and metadata["chat_id"]
and "message_id" in metadata
and metadata["message_id"]
):
event_emitter = get_event_emitter(metadata)
if not isinstance(response, StreamingResponse):
if event_emitter:
if "selected_model_id" in response:
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"selectedModelId": response["selected_model_id"],
},
)
if response.get("choices", [])[0].get("message", {}).get("content"):
content = response["choices"][0]["message"]["content"]
if content:
await event_emitter(
{
"type": "chat:completion",
"data": response,
}
)
title = Chats.get_chat_title_by_id(metadata["chat_id"])
await event_emitter(
{
"type": "chat:completion",
"data": {
"done": True,
"content": content,
"title": title,
},
}
)
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": content,
},
)
# Send a webhook notification if the user is not active
# get_active_status_by_user_id returns a bool, so test its truthiness directly
if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
post_webhook(
webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{
"action": "chat",
"message": content,
"title": title,
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
},
)
await background_tasks_handler()
return response
else:
return response
if not any(
content_type in response.headers["Content-Type"]
for content_type in ["text/event-stream", "application/x-ndjson"]
):
return response
content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
if event_emitter:
if not is_openai and not is_ollama:
return response
task_id = str(uuid4()) # Create a unique task ID.
async def stream_wrapper(original_generator, events):
def wrap_item(item):
return f"data: {item}\n\n" if is_openai else f"{item}\n"
# Handle as a background task
async def post_response_handler(response, events):
message = Chats.get_message_by_id_and_message_id(
metadata["chat_id"], metadata["message_id"]
)
content = message.get("content", "") if message else ""
for event in events:
yield wrap_item(json.dumps(event))
try:
for event in events:
await event_emitter(
{
"type": "chat:completion",
"data": event,
}
)
async for data in original_generator:
yield data
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
**event,
},
)
return StreamingResponse(
stream_wrapper(response.body_iterator, events),
headers=dict(response.headers),
)
async for line in response.body_iterator:
line = line.decode("utf-8") if isinstance(line, bytes) else line
data = line
# Skip empty lines
if not data.strip():
continue
# "data: " is the prefix for each event
if not data.startswith("data: "):
continue
# Remove the prefix
data = data[len("data: ") :]
try:
data = json.loads(data)
if "selected_model_id" in data:
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"selectedModelId": data["selected_model_id"],
},
)
else:
value = (
data.get("choices", [])[0]
.get("delta", {})
.get("content")
)
if value:
content = f"{content}{value}"
if ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": content,
},
)
else:
data = {
"content": content,
}
await event_emitter(
{
"type": "chat:completion",
"data": data,
}
)
except Exception as e:
done = "data: [DONE]" in line
if not done:
continue
title = Chats.get_chat_title_by_id(metadata["chat_id"])
data = {"done": True, "content": content, "title": title}
if not ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": content,
},
)
# Send a webhook notification if the user is not active
if get_active_status_by_user_id(user.id) is None:
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
post_webhook(
webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{
"action": "chat",
"message": content,
"title": title,
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
},
)
await event_emitter(
{
"type": "chat:completion",
"data": data,
}
)
await background_tasks_handler()
except asyncio.CancelledError:
print("Task was cancelled!")
await event_emitter({"type": "task-cancelled"})
if not ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": content,
},
)
if response.background is not None:
await response.background()
# background_tasks.add_task(post_response_handler, response, events)
task_id, _ = create_task(post_response_handler(response, events))
return {"status": True, "task_id": task_id}
else:
# Fallback to the original response
async def stream_wrapper(original_generator, events):
def wrap_item(item):
return f"data: {item}\n\n"
for event in events:
yield wrap_item(json.dumps(event))
async for data in original_generator:
yield data
return StreamingResponse(
stream_wrapper(response.body_iterator, events),
headers=dict(response.headers),
background=response.background,
)
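
For orientation, the new post_response_handler boils down to one pattern: read "data: "-prefixed SSE lines, pull choices[0].delta.content out of each chunk, and accumulate the text. A self-contained sketch of just that pattern (all names here are illustrative, not part of the codebase):

import asyncio
import json

async def fake_sse_stream():
    # Stands in for response.body_iterator; emits OpenAI-style delta chunks.
    for chunk in ["Hel", "lo", " world"]:
        yield f'data: {json.dumps({"choices": [{"delta": {"content": chunk}}]})}\n\n'
    yield "data: [DONE]\n\n"

async def accumulate_sse_content(stream) -> str:
    content = ""
    async for line in stream:
        data = line.strip()
        if not data.startswith("data: "):
            continue
        data = data[len("data: "):]
        if data == "[DONE]":
            break
        value = json.loads(data).get("choices", [{}])[0].get("delta", {}).get("content")
        if value:
            content += value  # same accumulation step as the handler above
    return content

print(asyncio.run(accumulate_sse_content(fake_sse_stream())))  # -> Hello world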

View File

@@ -7,6 +7,34 @@ from pathlib import Path
from typing import Callable, Optional
def get_message_list(messages, message_id):
"""
Reconstructs a list of messages in order up to the specified message_id.
:param messages: Message history dict containing all messages, keyed by message id
:param message_id: ID of the message whose chain should be reconstructed
:return: List of messages ordered from the root to the given message
"""
# Find the message by its id
current_message = messages.get(message_id)
if not current_message:
return f"Message ID {message_id} not found in the history."
# Reconstruct the chain by following the parentId links
message_list = []
while current_message:
message_list.insert(
0, current_message
) # Insert the message at the beginning of the list
parent_id = current_message["parentId"]
current_message = messages.get(parent_id) if parent_id else None
return message_list
def get_messages_content(messages: list[dict]) -> str:
return "\n".join(
[
@@ -40,6 +68,13 @@ def get_last_user_message(messages: list[dict]) -> Optional[str]:
return get_content_from_message(message)
def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "assistant":
return message
return None
def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
for message in reversed(messages):
if message["role"] == "assistant":

View File

@@ -58,7 +58,6 @@ async def get_all_base_models(request: Request):
return models
@cached(ttl=3)
async def get_all_models(request):
models = await get_all_base_models(request)
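
The @cached(ttl=3) decorator memoizes the coroutine's result so repeated model listings within a short window skip the upstream round trips. A minimal sketch of the same idea, assuming the decorator is aiocache's cached (the import sits outside this hunk):

import asyncio
from aiocache import cached

@cached(ttl=3)  # results are reused for up to 3 seconds per argument set
async def expensive_lookup(key: str) -> str:
    await asyncio.sleep(1)  # stand-in for a slow upstream call
    return f"value-for-{key}"

async def main():
    print(await expensive_lookup("a"))  # slow: populates the cache
    print(await expensive_lookup("a"))  # fast: served from the cache within the TTL

asyncio.run(main())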

View File

@@ -14,13 +14,16 @@ from starlette.responses import RedirectResponse
from open_webui.models.auths import Auths
from open_webui.models.users import Users
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_OAUTH_ROLE_MANAGEMENT,
ENABLE_OAUTH_GROUP_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_GROUPS_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
@@ -44,7 +47,9 @@ auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
@@ -119,6 +124,61 @@ class OAuthManager:
return role
def update_user_groups(self, user, user_data, default_permissions):
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups()
# Remove groups that user is no longer a part of
for group_model in user_current_groups:
if group_model.name not in user_oauth_groups:
# Remove group from user
user_ids = group_model.user_ids
user_ids = [i for i in user_ids if i != user.id]
# Fall back to default permissions if the group was created but permissions were never saved
group_permissions = group_model.permissions
if not group_permissions:
group_permissions = default_permissions
update_form = GroupUpdateForm(
name=group_model.name,
description=group_model.description,
permissions=group_permissions,
user_ids=user_ids,
)
Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
)
# Add user to new groups
for group_model in all_available_groups:
if group_model.name in user_oauth_groups and not any(
gm.name == group_model.name for gm in user_current_groups
):
# Add user to group
user_ids = group_model.user_ids
user_ids.append(user.id)
# Fall back to default permissions if the group was created but permissions were never saved
group_permissions = group_model.permissions
if not group_permissions:
group_permissions = default_permissions
update_form = GroupUpdateForm(
name=group_model.name,
description=group_model.description,
permissions=group_permissions,
user_ids=user_ids,
)
Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
)
async def handle_login(self, provider, request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
@@ -254,6 +314,13 @@ class OAuthManager:
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
)
if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT:
self.update_user_groups(
user=user,
user_data=user_data,
default_permissions=request.app.state.config.USER_PERMISSIONS,
)
# Set the cookie token
response.set_cookie(
key="token",

View File

@@ -154,9 +154,16 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
)
ollama_payload["stream"] = openai_payload.get("stream", False)
if "format" in openai_payload:
ollama_payload["format"] = openai_payload["format"]
# If there are advanced parameters in the payload, format them in Ollama's options field
ollama_options = {}
if openai_payload.get("options"):
ollama_payload["options"] = openai_payload["options"]
ollama_options = openai_payload["options"]
# Handle parameters which map directly
for param in ["temperature", "top_p", "seed"]:
if param in openai_payload:
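
The converter now gathers advanced parameters into a single ollama_options dict instead of forwarding openai_payload["options"] verbatim. A trimmed sketch of that collection step (helper name hypothetical; the real function maps more parameters than shown here):

def collect_ollama_options(openai_payload: dict) -> dict:
    options = dict(openai_payload.get("options") or {})
    # Parameters that map directly, per the loop above.
    for param in ["temperature", "top_p", "seed"]:
        if param in openai_payload:
            options[param] = openai_payload[param]
    return options

print(collect_ollama_options({"temperature": 0.2, "seed": 7}))
# {'temperature': 0.2, 'seed': 7}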

View File

@@ -29,7 +29,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
(
(
data.get("eval_count", 0)
/ ((data.get("eval_duration", 0) / 1_000_000_000))
/ ((data.get("eval_duration", 0) / 10_000_000))
)
* 100
),
@@ -43,12 +43,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
(
(
data.get("prompt_eval_count", 0)
/ (
(
data.get("prompt_eval_duration", 0)
/ 1_000_000_000
)
)
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
)
* 100
),
@@ -57,20 +52,12 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
if data.get("prompt_eval_duration", 0) > 0
else "N/A"
),
"total_duration": round(
((data.get("total_duration", 0) / 1_000_000) * 100), 2
),
"load_duration": round(
((data.get("load_duration", 0) / 1_000_000) * 100), 2
),
"total_duration": data.get("total_duration", 0),
"load_duration": data.get("load_duration", 0),
"prompt_eval_count": data.get("prompt_eval_count", 0),
"prompt_eval_duration": round(
((data.get("prompt_eval_duration", 0) / 1_000_000) * 100), 2
),
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
"eval_count": data.get("eval_count", 0),
"eval_duration": round(
((data.get("eval_duration", 0) / 1_000_000) * 100), 2
),
"eval_duration": data.get("eval_duration", 0),
"approximate_total": (
lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s"
)((data.get("total_duration", 0) or 0) // 1_000_000_000),
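
Ollama reports all durations in nanoseconds. The old code scaled them by /1_000_000 and then *100, inflating the supposed millisecond values a hundredfold; the new code passes the raw nanosecond values through and fixes the token-rate math: dividing by 10_000_000 and then multiplying by 100 is algebraically count * 1e9 / duration, i.e. true tokens per second. A worked check:

eval_count = 100
eval_duration_ns = 2_000_000_000  # 2 seconds, in Ollama's nanosecond units

tokens_per_second = eval_count / (eval_duration_ns / 1_000_000_000)
as_computed_above = (eval_count / (eval_duration_ns / 10_000_000)) * 100

print(tokens_per_second, as_computed_above)  # 50.0 50.0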

View File

@@ -11,6 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
def post_webhook(url: str, message: str, event_data: dict) -> bool:
try:
log.debug(f"post_webhook: {url}, {message}, {event_data}")
payload = {}
# Slack and Google Chat Webhooks
@@ -18,7 +19,11 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
payload["text"] = message
# Discord Webhooks
elif "https://discord.com/api/webhooks" in url:
payload["content"] = message
payload["content"] = (
message
if len(message) < 2000
else f"{message[: 2000 - 20]}... (truncated)"
)
# Microsoft Teams Webhooks
elif "webhook.office.com" in url:
action = event_data.get("action", "undefined")
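
Discord rejects messages whose content exceeds 2,000 characters, so the Discord branch now truncates long messages and appends a marker, reserving 20 characters of headroom for it. The same logic as a tiny standalone helper (name hypothetical):

def discord_content(message: str, limit: int = 2000) -> str:
    return message if len(message) < limit else f"{message[: limit - 20]}... (truncated)"

print(len(discord_content("x" * 5000)))  # 1995: 1980 kept chars + 15-char marker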

View File

@@ -3,7 +3,7 @@ uvicorn[standard]==0.30.6
pydantic==2.9.2
python-multipart==0.0.18
Flask==3.0.3
Flask==3.1.0
Flask-Cors==5.0.0
python-socketio==5.11.3
@@ -18,7 +18,7 @@ aiofiles
sqlalchemy==2.0.32
alembic==1.14.0
peewee==3.17.6
peewee==3.17.8
peewee-migrate==1.12.2
psycopg2-binary==2.9.9
pgvector==0.3.5
@@ -55,7 +55,7 @@ einops==0.8.0
ftfy==6.2.3
pypdf==4.3.1
fpdf2==2.7.9
fpdf2==2.8.2
pymdown-extensions==10.11.2
docx2txt==0.8
python-pptx==1.0.0
@@ -67,7 +67,7 @@ pandas==2.2.3
openpyxl==3.1.5
pyxlsb==1.0.10
xlrd==2.0.1
validators==0.33.0
validators==0.34.0
psutil
sentencepiece
soundfile==0.12.1
@@ -78,7 +78,7 @@ rank-bm25==0.2.2
faster-whisper==1.0.3
PyJWT[crypto]==2.9.0
PyJWT[crypto]==2.10.1
authlib==1.3.2
black==24.8.0
@@ -90,6 +90,11 @@ extract_msg
pydub
duckduckgo-search~=6.3.5
## Google Drive
google-api-python-client
google-auth-httplib2
google-auth-oauthlib
## Tests
docker~=7.1.0
pytest~=8.3.2