Merge branch 'dev' into k_reranker

This commit is contained in:
Timothy Jaeryang Baek
2025-03-26 20:50:31 -07:00
committed by GitHub
147 changed files with 6065 additions and 1350 deletions

View File

@@ -3,6 +3,7 @@ import logging
import os
import shutil
import base64
import redis
from datetime import datetime
from pathlib import Path
@@ -17,6 +18,7 @@ from open_webui.env import (
DATA_DIR,
DATABASE_URL,
ENV,
REDIS_URL,
FRONTEND_BUILD_DIR,
OFFLINE_MODE,
OPEN_WEBUI_DIR,
@@ -248,9 +250,14 @@ class PersistentConfig(Generic[T]):
class AppConfig:
_state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None
def __init__(self):
def __init__(self, redis_url: Optional[str] = None):
super().__setattr__("_state", {})
if redis_url:
super().__setattr__(
"_redis", redis.Redis.from_url(redis_url, decode_responses=True)
)
def __setattr__(self, key, value):
if isinstance(value, PersistentConfig):
@@ -259,7 +266,31 @@ class AppConfig:
self._state[key].value = value
self._state[key].save()
if self._redis:
redis_key = f"open-webui:config:{key}"
self._redis.set(redis_key, json.dumps(self._state[key].value))
def __getattr__(self, key):
if key not in self._state:
raise AttributeError(f"Config key '{key}' not found")
# If Redis is available, check for an updated value
if self._redis:
redis_key = f"open-webui:config:{key}"
redis_value = self._redis.get(redis_key)
if redis_value is not None:
try:
decoded_value = json.loads(redis_value)
# Update the in-memory value if different
if self._state[key].value != decoded_value:
self._state[key].value = decoded_value
log.info(f"Updated {key} from Redis: {decoded_value}")
except json.JSONDecodeError:
log.error(f"Invalid JSON format in Redis for {key}: {redis_value}")
return self._state[key].value
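A minimal usage sketch of the Redis-backed AppConfig above (not part of the diff; the Redis URL and the ENABLE_SIGNUP entry are illustrative): a value written through one instance is pushed to Redis in __setattr__ and picked up by another instance on its next attribute read.
# Sketch only: assumes a reachable Redis and that both processes register the
# same PersistentConfig entry; names and URL below are illustrative.
config_a = AppConfig(redis_url="redis://localhost:6379/0")
config_b = AppConfig(redis_url="redis://localhost:6379/0")
config_a.ENABLE_SIGNUP = PersistentConfig("ENABLE_SIGNUP", "ui.enable_signup", True)
config_b.ENABLE_SIGNUP = PersistentConfig("ENABLE_SIGNUP", "ui.enable_signup", True)
config_a.ENABLE_SIGNUP = False   # saves to the DB and SETs open-webui:config:ENABLE_SIGNUP
print(config_b.ENABLE_SIGNUP)    # __getattr__ re-reads Redis and prints False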
@@ -1276,7 +1307,7 @@ Strictly return in JSON format:
ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig(
"ENABLE_AUTOCOMPLETE_GENERATION",
"task.autocomplete.enable",
os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true",
os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "False").lower() == "true",
)
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig(
@@ -1548,8 +1579,10 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True)
OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
OPENSEARCH_CERT_VERIFY = (
os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
)
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
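The OPENSEARCH_SSL / OPENSEARCH_CERT_VERIFY change matters because os.environ.get returns a string whenever the variable is set, so the old defaults never yielded a real boolean; a quick illustration (plain Python, values assumed):
import os
os.environ["OPENSEARCH_SSL"] = "false"
bool(os.environ.get("OPENSEARCH_SSL", True))                  # True  - "false" is a truthy string
os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"    # False - what the new parsing yields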
@@ -1623,6 +1656,12 @@ TIKA_SERVER_URL = PersistentConfig(
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
DOCLING_SERVER_URL = PersistentConfig(
"DOCLING_SERVER_URL",
"rag.docling_server_url",
os.getenv("DOCLING_SERVER_URL", "http://docling:5001"),
)
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
"DOCUMENT_INTELLIGENCE_ENDPOINT",
"rag.document_intelligence_endpoint",
@@ -1955,6 +1994,12 @@ TAVILY_API_KEY = PersistentConfig(
os.getenv("TAVILY_API_KEY", ""),
)
TAVILY_EXTRACT_DEPTH = PersistentConfig(
"TAVILY_EXTRACT_DEPTH",
"rag.web.search.tavily_extract_depth",
os.getenv("TAVILY_EXTRACT_DEPTH", "basic"),
)
JINA_API_KEY = PersistentConfig(
"JINA_API_KEY",
"rag.web.search.jina_api_key",
@@ -2041,6 +2086,12 @@ PLAYWRIGHT_WS_URI = PersistentConfig(
os.environ.get("PLAYWRIGHT_WS_URI", None),
)
PLAYWRIGHT_TIMEOUT = PersistentConfig(
"PLAYWRIGHT_TIMEOUT",
"rag.web.loader.engine.playwright.timeout",
int(os.environ.get("PLAYWRIGHT_TIMEOUT", "10")),
)
FIRECRAWL_API_KEY = PersistentConfig(
"FIRECRAWL_API_KEY",
"firecrawl.api_key",

View File

@@ -105,7 +105,6 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
@@ -130,7 +129,6 @@ else:
except Exception:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
@@ -161,7 +159,6 @@ try:
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
@@ -192,7 +189,6 @@ for version in soup.find_all("h2"):
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
@@ -209,7 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
)
####################################
# WEBUI_BUILD_HASH
####################################
@@ -244,7 +239,6 @@ if FROM_INIT_PY:
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
@@ -256,7 +250,6 @@ if FROM_INIT_PY:
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
).resolve()
####################################
# Database
####################################
@@ -321,7 +314,6 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
ENABLE_REALTIME_CHAT_SAVE = (
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
)
@@ -330,7 +322,7 @@ ENABLE_REALTIME_CHAT_SAVE = (
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
REDIS_URL = os.environ.get("REDIS_URL", "")
####################################
# WEBUI_AUTH (Required for security)
@@ -399,18 +391,16 @@ else:
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""),
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"),
)
if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 5
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10
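For illustration, a small helper that mirrors the new parsing above (not part of the diff): the default is now 10 seconds, an explicitly empty variable disables the timeout, and unparsable input falls back to 10 instead of 5.
def _parse_model_list_timeout(raw: str = "10"):
    # "" -> no timeout; invalid input -> 10 (the new fallback); otherwise int(raw)
    if raw == "":
        return None
    try:
        return int(raw)
    except Exception:
        return 10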
####################################
# OFFLINE_MODE
@@ -424,13 +414,12 @@ if OFFLINE_MODE:
####################################
# AUDIT LOGGING
####################################
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
@@ -442,3 +431,26 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
####################################
# OPENTELEMETRY
####################################
ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
) # e.g. key1=val1,key2=val2
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################
PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()

View File

@@ -223,6 +223,9 @@ async def generate_function_chat_completion(
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__chat_id__": metadata.get("chat_id", None),
"__session_id__": metadata.get("session_id", None),
"__message_id__": metadata.get("message_id", None),
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,

View File

@@ -84,7 +84,7 @@ from open_webui.routers.retrieval import (
get_rf,
)
from open_webui.internal.db import Session
from open_webui.internal.db import Session, engine
from open_webui.models.functions import Functions
from open_webui.models.models import Models
@@ -155,6 +155,7 @@ from open_webui.config import (
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
PLAYWRIGHT_WS_URI,
PLAYWRIGHT_TIMEOUT,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
RAG_WEB_LOADER_ENGINE,
@@ -186,6 +187,7 @@ from open_webui.config import (
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
DOCLING_SERVER_URL,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
RAG_TOP_K,
@@ -213,6 +215,7 @@ from open_webui.config import (
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
TAVILY_API_KEY,
TAVILY_EXTRACT_DEPTH,
BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY,
BRAVE_SEARCH_API_KEY,
@@ -313,6 +316,7 @@ from open_webui.env import (
AUDIT_EXCLUDED_PATHS,
AUDIT_LOG_LEVEL,
CHANGELOG,
REDIS_URL,
GLOBAL_LOG_LEVEL,
MAX_BODY_LOG_SIZE,
SAFE_MODE,
@@ -328,6 +332,7 @@ from open_webui.env import (
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
ENABLE_OTEL,
)
@@ -355,7 +360,6 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
@@ -419,11 +423,24 @@ app = FastAPI(
oauth_manager = OAuthManager(app)
app.state.config = AppConfig()
app.state.config = AppConfig(redis_url=REDIS_URL)
app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None
########################################
#
# OPENTELEMETRY
#
########################################
if ENABLE_OTEL:
from open_webui.utils.telemetry.setup import setup as setup_opentelemetry
setup_opentelemetry(app=app, db_engine=engine)
########################################
#
# OLLAMA
@@ -551,6 +568,7 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
@@ -614,8 +632,10 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
app.state.EMBEDDING_FUNCTION = None
app.state.ef = None
@@ -949,14 +969,24 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
models = await get_all_models(request, user=user)
all_models = await get_all_models(request, user=user)
# Filter out filter pipelines
models = [
model
for model in models
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
]
models = []
for model in all_models:
# Filter out filter pipelines
if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
continue
model_tags = [
tag.get("name")
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
]
tags = [tag.get("name") for tag in model.get("tags", [])]
tags = list(set(model_tags + tags))
model["tags"] = [{"name": tag} for tag in tags]
models.append(model)
model_order_list = request.app.state.config.MODEL_ORDER_LIST
if model_order_list:

View File

@@ -105,7 +105,7 @@ class TikaLoader:
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
@@ -117,6 +117,52 @@ class TikaLoader:
raise Exception(f"Error calling Tika: {r.reason}")
class DoclingLoader:
def __init__(self, url, file_path=None, mime_type=None):
self.url = url.rstrip("/")
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
files = {
"files": (
self.file_path,
f,
self.mime_type or "application/octet-stream",
)
}
params = {
"image_export_mode": "placeholder",
"table_mode": "accurate",
}
endpoint = f"{self.url}/v1alpha/convert/file"
r = requests.post(endpoint, files=files, data=params)
if r.ok:
result = r.json()
document_data = result.get("document", {})
text = document_data.get("md_content", "<No text content found>")
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text)
return [Document(page_content=text, metadata=metadata)]
else:
error_msg = f"Error calling Docling API: {r.reason}"
if r.text:
try:
error_data = r.json()
if "detail" in error_data:
error_msg += f" - {error_data['detail']}"
except Exception:
error_msg += f" - {r.text}"
raise Exception(f"Error calling Docling: {error_msg}")
class Loader:
def __init__(self, engine: str = "", **kwargs):
self.engine = engine
@@ -149,6 +195,12 @@ class Loader:
file_path=file_path,
mime_type=file_content_type,
)
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""

View File

@@ -0,0 +1,93 @@
import requests
import logging
from typing import Iterator, List, Literal, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class TavilyLoader(BaseLoader):
"""Extract web page content from URLs using Tavily Extract API.
This is a LangChain document loader that uses Tavily's Extract API to
retrieve content from web pages and return it as Document objects.
Args:
urls: URL or list of URLs to extract content from.
api_key: The Tavily API key.
extract_depth: Depth of extraction, either "basic" or "advanced".
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
def __init__(
self,
urls: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
continue_on_failure: bool = True,
) -> None:
"""Initialize Tavily Extract client.
Args:
urls: URL or list of URLs to extract content from.
api_key: The Tavily API key.
include_images: Whether to include images in the extraction.
extract_depth: Depth of extraction, either "basic" or "advanced".
advanced extraction retrieves more data, including tables and
embedded content, with higher success but may increase latency.
basic costs 1 credit per 5 successful URL extractions,
advanced costs 2 credits per 5 successful URL extractions.
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
if not urls:
raise ValueError("At least one URL must be provided.")
self.api_key = api_key
self.urls = urls if isinstance(urls, list) else [urls]
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.api_url = "https://api.tavily.com/extract"
def lazy_load(self) -> Iterator[Document]:
"""Extract and yield documents from the URLs using Tavily Extract API."""
batch_size = 20
for i in range(0, len(self.urls), batch_size):
batch_urls = self.urls[i : i + batch_size]
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
# Use string for single URL, array for multiple URLs
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
# Make the API call
response = requests.post(self.api_url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
# Process successful results
for result in response_data.get("results", []):
url = result.get("url", "")
content = result.get("raw_content", "")
if not content:
log.warning(f"No content extracted from {url}")
continue
# Add URLs as metadata
metadata = {"source": url}
yield Document(
page_content=content,
metadata=metadata,
)
for failed in response_data.get("failed_results", []):
url = failed.get("url", "")
error = failed.get("error", "Unknown error")
log.error(f"Failed to extract content from {url}: {error}")
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {batch_urls}: {e}")
else:
raise e
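A minimal sketch of driving the new TavilyLoader directly (the API key and URLs are placeholders): URLs are sent in batches of 20, successful extractions are yielded as Documents with the source URL in metadata, and failures are logged when continue_on_failure is True.
loader = TavilyLoader(
    urls=["https://example.com", "https://example.org"],
    api_key="tvly-placeholder-key",       # placeholder, not a real key
    extract_depth="basic",
    continue_on_failure=True,
)
for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))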

View File

@@ -189,8 +189,7 @@ def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False
) -> dict:
# Initialize lists to store combined data
combined = []
seen_hashes = set() # To store unique document hashes
combined = dict() # To store documents with unique document hashes
for data in query_results:
distances = data["distances"][0]
@@ -203,10 +202,19 @@ def merge_and_sort_query_results(
document.encode()
).hexdigest() # Compute a hash for uniqueness
if doc_hash not in seen_hashes:
seen_hashes.add(doc_hash)
combined.append((distance, document, metadata))
if doc_hash not in combined.keys():
combined[doc_hash] = (distance, document, metadata)
continue # if doc is new, no further comparison is needed
# if doc is already in, but new distance is better, update
if not reverse and distance < combined[doc_hash][0]:
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
combined[doc_hash] = (distance, document, metadata)
if reverse and distance > combined[doc_hash][0]:
combined[doc_hash] = (distance, document, metadata)
combined = list(combined.values())
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=reverse)
@@ -215,6 +223,12 @@ def merge_and_sort_query_results(
zip(*combined[:k]) if combined else ([], [], [])
)
# if chromaDB, the distance is 0 (best) to 2 (worst)
# re-order to -1 (worst) to 1 (best) for relevance score
if not reverse:
sorted_distances = tuple(-dist for dist in sorted_distances)
sorted_distances = tuple(dist + 1 for dist in sorted_distances)
# Create and return the output dictionary
return {
"distances": [list(sorted_distances)],
@@ -306,13 +320,8 @@ def query_collection_with_hybrid_search(
raise Exception(
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
)
if VECTOR_DB == "chroma":
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
return merge_and_sort_query_results(results, k=k_reranker, reverse=False)
else:
return merge_and_sort_query_results(results, k=k_reranker, reverse=True)
return merge_and_sort_query_results(results, k=k, reverse=True)
def get_embedding_function(
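A worked example of the merge logic above, with made-up numbers: if two collections return the same document with Chroma-style distances 0.3 and 0.5, the dedup keeps 0.3 (reverse=False, lower is better), and the kept distance is then remapped from Chroma's 0 (best)..2 (worst) range onto a -1..1 relevance score.
# Illustrative values only.
# duplicate doc, distances 0.3 and 0.5 -> combined[doc_hash] keeps 0.3
# reported relevance score: -(0.3) + 1 = 0.7    (0 -> 1.0 best, 2 -> -1.0 worst)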

View File

@@ -166,12 +166,19 @@ class ChromaClient:
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
collection = self.client.get_collection(name=collection_name)
if collection:
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
try:
collection = self.client.get_collection(name=collection_name)
if collection:
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
except Exception as e:
# If collection doesn't exist, that's fine - nothing to delete
log.debug(
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
)
pass
def reset(self):
# Resets the database. This will delete all collections and item entries.

View File

@@ -1,4 +1,5 @@
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
@@ -21,7 +22,13 @@ class OpenSearchClient:
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
)
def _get_index_name(self, collection_name: str) -> str:
return f"{self.index_prefix}_{collection_name}"
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
ids = []
documents = []
metadatas = []
@@ -31,9 +38,12 @@ class OpenSearchClient:
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _result_to_search_result(self, result) -> SearchResult:
if not result["hits"]["hits"]:
return None
ids = []
distances = []
documents = []
@@ -46,34 +56,40 @@ class OpenSearchClient:
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
ids=[ids],
distances=[distances],
documents=[documents],
metadatas=[metadatas],
)
def _create_index(self, collection_name: str, dimension: int):
body = {
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": true,
"type": "knn_vector",
"dimension": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "ip", # Use inner product to approximate cosine similarity
"space_type": "innerproduct", # Use inner product to approximate cosine similarity
"engine": "faiss",
"ef_construction": 128,
"m": 16,
"parameters": {
"ef_construction": 128,
"m": 16,
},
},
},
"text": {"type": "text"},
"metadata": {"type": "object"},
}
}
},
}
self.client.indices.create(
index=f"{self.index_prefix}_{collection_name}", body=body
index=self._get_index_name(collection_name), body=body
)
def _create_batches(self, items: list[VectorItem], batch_size=100):
@@ -83,39 +99,45 @@ class OpenSearchClient:
def has_collection(self, collection_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(
index=f"{self.index_prefix}_{collection_name}"
)
return self.client.indices.exists(index=self._get_index_name(collection_name))
def delete_colleciton(self, collection_name: str):
def delete_collection(self, collection_name: str):
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}")
self.client.indices.delete(index=self._get_index_name(collection_name))
def search(
self, collection_name: str, vectors: list[list[float]], limit: int
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
},
}
try:
if not self.has_collection(collection_name):
return None
result = self.client.search(
index=f"{self.index_prefix}_{collection_name}", body=query
)
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
"params": {
"field": "vector",
"query_value": vectors[0],
}, # Assuming single query vector
},
}
},
}
return self._result_to_search_result(result)
result = self.client.search(
index=self._get_index_name(collection_name), body=query
)
return self._result_to_search_result(result)
except Exception as e:
return None
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
@@ -129,13 +151,15 @@ class OpenSearchClient:
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}}
)
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}_{collection_name}",
index=self._get_index_name(collection_name),
body=query_body,
size=size,
)
@@ -146,14 +170,14 @@ class OpenSearchClient:
return None
def _create_index_if_not_exists(self, collection_name: str, dimension: int):
if not self.has_index(collection_name):
if not self.has_collection(collection_name):
self._create_index(collection_name, dimension)
def get(self, collection_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(
index=f"{self.index_prefix}_{collection_name}", body=query
index=self._get_index_name(collection_name), body=query
)
return self._result_to_get_result(result)
@@ -165,18 +189,18 @@ class OpenSearchClient:
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
"_op_type": "index",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
self.client.bulk(actions)
bulk(self.client, actions)
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
@@ -186,26 +210,47 @@ class OpenSearchClient:
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_index": f"{self.index_prefix}_{collection_name}",
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
"_op_type": "update",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"doc": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
"doc_as_upsert": True,
}
for item in batch
]
self.client.bulk(actions)
bulk(self.client, actions)
def delete(self, collection_name: str, ids: list[str]):
actions = [
{"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}}
for id in ids
]
self.client.bulk(body=actions)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
if ids:
actions = [
{
"_op_type": "delete",
"_index": self._get_index_name(collection_name),
"_id": id,
}
for id in ids
]
bulk(self.client, actions)
elif filter:
query_body = {
"query": {"bool": {"filter": []}},
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
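For reference, a sketch of the opensearch-py bulk helper used above with the upsert-style action this class now builds (host, credentials, index name, and document values are all illustrative):
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk

# Illustrative connection; real settings come from the OPENSEARCH_* env vars.
client = OpenSearch(
    hosts=["https://localhost:9200"], http_auth=("user", "pass"), verify_certs=False
)
actions = [
    {
        "_op_type": "update",
        "_index": "open_webui_example-collection",   # self._get_index_name(...)
        "_id": "doc-1",
        "doc": {"vector": [0.1, 0.2, 0.3], "text": "hello", "metadata": {"file_id": "f1"}},
        "doc_as_upsert": True,                        # insert if missing, update otherwise
    }
]
bulk(client, actions)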

View File

@@ -24,13 +24,17 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoa
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
PLAYWRIGHT_WS_URI,
PLAYWRIGHT_TIMEOUT,
RAG_WEB_LOADER_ENGINE,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
TAVILY_API_KEY,
TAVILY_EXTRACT_DEPTH,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -113,7 +117,47 @@ def verify_ssl_cert(url: str) -> bool:
return False
class SafeFireCrawlLoader(BaseLoader):
class RateLimitMixin:
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
class URLProcessingMixin:
def _verify_ssl_cert(self, url: str) -> bool:
"""Verify SSL certificate for a URL."""
return verify_ssl_cert(url)
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
def __init__(
self,
web_paths,
@@ -184,7 +228,7 @@ class SafeFireCrawlLoader(BaseLoader):
yield from loader.lazy_load()
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
log.exception(f"Error loading {url}: {e}")
continue
raise e
@@ -204,47 +248,124 @@ class SafeFireCrawlLoader(BaseLoader):
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
log.exception(f"Error loading {url}: {e}")
continue
raise e
def _verify_ssl_cert(self, url: str) -> bool:
return verify_ssl_cert(url)
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
def __init__(
self,
web_paths: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
continue_on_failure: bool = True,
requests_per_second: Optional[float] = None,
verify_ssl: bool = True,
trust_env: bool = False,
proxy: Optional[Dict[str, str]] = None,
):
"""Initialize SafeTavilyLoader with rate limiting and SSL verification support.
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
Args:
web_paths: List of URLs/paths to process.
api_key: The Tavily API key.
extract_depth: Depth of extraction ("basic" or "advanced").
continue_on_failure: Whether to continue if extraction of a URL fails.
requests_per_second: Number of requests per second to limit to.
verify_ssl: If True, verify SSL certificates.
trust_env: If True, use proxy settings from environment variables.
proxy: Optional proxy configuration.
"""
# Initialize proxy configuration if using environment variables
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
# Store parameters for creating TavilyLoader instances
self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
self.api_key = api_key
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.verify_ssl = verify_ssl
self.trust_env = trust_env
self.proxy = proxy
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
# Add rate limiting
self.requests_per_second = requests_per_second
self.last_request_time = None
def lazy_load(self) -> Iterator[Document]:
"""Load documents with rate limiting support, delegating to TavilyLoader."""
valid_urls = []
for url in self.web_paths:
try:
self._safe_process_url_sync(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
return
raise ValueError("No valid URLs to process after SSL verification")
try:
loader = TavilyLoader(
urls=valid_urls,
api_key=self.api_key,
extract_depth=self.extract_depth,
continue_on_failure=self.continue_on_failure,
)
yield from loader.lazy_load()
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error extracting content from URLs: {e}")
else:
raise e
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async version with rate limiting and SSL verification."""
valid_urls = []
for url in self.web_paths:
try:
await self._safe_process_url(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
return
raise ValueError("No valid URLs to process after SSL verification")
try:
loader = TavilyLoader(
urls=valid_urls,
api_key=self.api_key,
extract_depth=self.extract_depth,
continue_on_failure=self.continue_on_failure,
)
async for document in loader.alazy_load():
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading URLs: {e}")
else:
raise e
class SafePlaywrightURLLoader(PlaywrightURLLoader):
class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
"""Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
Attributes:
@@ -256,6 +377,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
headless (bool): If True, the browser will run in headless mode.
proxy (dict): Proxy override settings for the Playwright session.
playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
playwright_timeout (Optional[int]): Maximum operation time in milliseconds.
"""
def __init__(
@@ -269,6 +391,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
remove_selectors: Optional[List[str]] = None,
proxy: Optional[Dict[str, str]] = None,
playwright_ws_url: Optional[str] = None,
playwright_timeout: Optional[int] = 10000,
):
"""Initialize with additional safety parameters and remote browser support."""
@@ -295,6 +418,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
self.last_request_time = None
self.playwright_ws_url = playwright_ws_url
self.trust_env = trust_env
self.playwright_timeout = playwright_timeout
def lazy_load(self) -> Iterator[Document]:
"""Safely load URLs synchronously with support for remote browser."""
@@ -311,7 +435,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
try:
self._safe_process_url_sync(url)
page = browser.new_page()
response = page.goto(url)
response = page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
@@ -320,7 +444,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
log.exception(f"Error loading {url}: {e}")
continue
raise e
browser.close()
@@ -342,7 +466,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
try:
await self._safe_process_url(url)
page = await browser.new_page()
response = await page.goto(url)
response = await page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
@@ -351,46 +475,11 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
log.exception(f"Error loading {url}: {e}")
continue
raise e
await browser.close()
def _verify_ssl_cert(self, url: str) -> bool:
return verify_ssl_cert(url)
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
@@ -472,7 +561,7 @@ class SafeWebBaseLoader(WebBaseLoader):
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.exception(e, "Error loading %s", path)
log.exception(f"Error loading {path}: {e}")
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async lazy load text from the url(s) in web_path."""
@@ -499,6 +588,7 @@ RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader
def get_web_loader(
@@ -518,13 +608,19 @@ def get_web_loader(
"trust_env": trust_env,
}
if PLAYWRIGHT_WS_URI.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
if RAG_WEB_LOADER_ENGINE.value == "playwright":
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000
if PLAYWRIGHT_WS_URI.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
if RAG_WEB_LOADER_ENGINE.value == "tavily":
web_loader_args["api_key"] = TAVILY_API_KEY.value
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
# Create the appropriate WebLoader based on the configuration
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
web_loader = WebLoaderClass(**web_loader_args)
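Note the unit conversion above: PLAYWRIGHT_TIMEOUT is configured in seconds (default 10) while Playwright's page.goto expects milliseconds, hence the multiplication. A tiny sketch with the default value:
PLAYWRIGHT_TIMEOUT_SECONDS = 10                          # configured value
playwright_timeout_ms = PLAYWRIGHT_TIMEOUT_SECONDS * 1000
# SafePlaywrightURLLoader then calls page.goto(url, timeout=10000)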

View File

@@ -625,7 +625,9 @@ def transcription(
):
log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
supported_filetypes = ("audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a")
if not file.content_type.startswith(supported_filetypes):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
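The switch to str.startswith with a tuple accepts content types that carry codec or parameter suffixes, which the old exact-match list rejected; for example (values illustrative):
supported_filetypes = ("audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a")
"audio/ogg; codecs=opus" in supported_filetypes               # False - old check rejected it
"audio/ogg; codecs=opus".startswith(supported_filetypes)      # True  - new check accepts it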

View File

@@ -210,7 +210,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
LDAP_APP_DN,
LDAP_APP_PASSWORD,
auto_bind="NONE",
authentication="SIMPLE",
authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS",
)
if not connection_app.bind():
raise HTTPException(400, detail="Application account bind failed")

View File

@@ -2,6 +2,8 @@ import json
import logging
from typing import Optional
from open_webui.socket.main import get_event_emitter
from open_webui.models.chats import (
ChatForm,
ChatImportForm,
@@ -372,6 +374,107 @@ async def update_chat_by_id(
)
############################
# UpdateChatMessageById
############################
class MessageForm(BaseModel):
content: str
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
async def update_chat_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(id)
if not chat:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if chat.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.upsert_message_to_chat_by_id_and_message_id(
id,
message_id,
{
"content": form_data.content,
},
)
event_emitter = get_event_emitter(
{
"user_id": user.id,
"chat_id": id,
"message_id": message_id,
},
False,
)
if event_emitter:
await event_emitter(
{
"type": "chat:message",
"data": {
"chat_id": id,
"message_id": message_id,
"content": form_data.content,
},
}
)
return ChatResponse(**chat.model_dump())
############################
# SendChatMessageEventById
############################
class EventForm(BaseModel):
type: str
data: dict
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
async def send_chat_message_event_by_id(
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(id)
if not chat:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if chat.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
event_emitter = get_event_emitter(
{
"user_id": user.id,
"chat_id": id,
"message_id": message_id,
}
)
try:
if event_emitter:
await event_emitter(form_data.model_dump())
else:
return False
return True
except:
return False
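A hedged client-side sketch of the two new endpoints, assuming the chats router is mounted under /api/v1/chats and bearer-token auth (both assumptions; the prefix is not shown in this diff):
import requests

base = "http://localhost:8080/api/v1/chats"       # assumed mount point
headers = {"Authorization": "Bearer <token>"}      # placeholder token
chat_id, message_id = "chat-123", "msg-456"        # placeholders

# UpdateChatMessageById: overwrite a message's content and emit chat:message
requests.post(f"{base}/{chat_id}/messages/{message_id}",
              json={"content": "edited text"}, headers=headers)

# SendChatMessageEventById: push an arbitrary event to that message's sessions
requests.post(f"{base}/{chat_id}/messages/{message_id}/event",
              json={"type": "status", "data": {"description": "processing"}},
              headers=headers)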
############################
# DeleteChatById
############################

View File

@@ -81,7 +81,7 @@ def upload_file(
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
else:
elif file.content_type not in ["image/png", "image/jpeg", "image/gif"]:
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)

View File

@@ -517,10 +517,8 @@ async def image_generations(
images = []
for image in res["data"]:
if "url" in image:
image_data, content_type = load_url_image_data(
image["url"], headers
)
if image_url := image.get("url", None):
image_data, content_type = load_url_image_data(image_url, headers)
else:
image_data, content_type = load_b64_image_data(image["b64_json"])

View File

@@ -437,14 +437,24 @@ def remove_file_from_knowledge_by_id(
)
# Remove content from the vector database
VECTOR_DB_CLIENT.delete(
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
)
try:
VECTOR_DB_CLIENT.delete(
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
)
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass
# Remove the file's collection from vector database
file_collection = f"file-{form_data.file_id}"
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
try:
# Remove the file's collection from vector database
file_collection = f"file-{form_data.file_id}"
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass
# Delete file from database
Files.delete_file_by_id(form_data.file_id)

View File

@@ -295,7 +295,7 @@ async def update_config(
}
@cached(ttl=3)
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
@@ -336,6 +336,7 @@ async def get_all_models(request: Request, user: UserModel = None):
)
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
model_ids = api_config.get("model_ids", [])
if len(model_ids) != 0 and "models" in response:
@@ -350,6 +351,10 @@ async def get_all_models(request: Request, user: UserModel = None):
for model in response.get("models", []):
model["model"] = f"{prefix_id}.{model['model']}"
if tags:
for model in response.get("models", []):
model["tags"] = tags
def merge_models_lists(model_lists):
merged_models = {}
@@ -1164,7 +1169,7 @@ async def generate_chat_completion(
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
# payload["keep_alive"] = -1 # keep alive forever
return await send_post_request(
url=f"{url}/api/chat",
payload=json.dumps(payload),

View File

@@ -36,6 +36,9 @@ from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.misc import (
convert_logit_bias_input_to_json,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
@@ -350,6 +353,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
)
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
if prefix_id:
for model in (
@@ -357,6 +361,12 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
):
model["id"] = f"{prefix_id}.{model['id']}"
if tags:
for model in (
response if isinstance(response, list) else response.get("data", [])
):
model["tags"] = tags
log.debug(f"get_all_models:responses() {responses}")
return responses
@@ -374,7 +384,7 @@ async def get_filtered_models(models, user):
return filtered_models
@cached(ttl=3)
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
@@ -396,6 +406,7 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
@@ -406,18 +417,21 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
if (model.get("id") or model.get("name"))
and (
"api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
)
]
)
@@ -666,6 +680,11 @@ async def generate_chat_completion(
del payload["max_tokens"]
# Convert the modified body back to JSON
if "logit_bias" in payload:
payload["logit_bias"] = json.loads(
convert_logit_bias_input_to_json(payload["logit_bias"])
)
payload = json.dumps(payload)
r = None

View File

@@ -90,8 +90,8 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
@@ -139,8 +139,8 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
try:
res = (

View File

@@ -358,6 +358,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -428,6 +429,7 @@ class DocumentIntelligenceConfigForm(BaseModel):
class ContentExtractionConfig(BaseModel):
engine: str = ""
tika_server_url: Optional[str] = None
docling_server_url: Optional[str] = None
document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
@@ -540,6 +542,9 @@ async def update_rag_config(
request.app.state.config.TIKA_SERVER_URL = (
form_data.content_extraction.tika_server_url
)
request.app.state.config.DOCLING_SERVER_URL = (
form_data.content_extraction.docling_server_url
)
if form_data.content_extraction.document_intelligence_config is not None:
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.content_extraction.document_intelligence_config.endpoint
@@ -648,6 +653,7 @@ async def update_rag_config(
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -994,6 +1000,7 @@ def process_file(
loader = Loader(
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,

View File

@@ -2,6 +2,7 @@ import logging
from typing import Optional
from open_webui.models.auths import Auths
from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.users import (
UserModel,
@@ -17,7 +18,10 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
from open_webui.utils.access_control import get_permissions
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -45,7 +49,7 @@ async def get_users(
@router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
return Groups.get_groups_by_member_id(user.id)
############################
@@ -54,8 +58,12 @@ async def get_user_groups(user=Depends(get_verified_user)):
@router.get("/permissions")
async def get_user_permissisions(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return user_permissions
############################
@@ -89,7 +97,7 @@ class UserPermissions(BaseModel):
@router.get("/default/permissions", response_model=UserPermissions)
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)):
return {
"workspace": WorkspacePermissions(
**request.app.state.config.USER_PERMISSIONS.get("workspace", {})
@@ -104,7 +112,7 @@ async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
@router.post("/default/permissions")
async def update_user_permissions(
async def update_default_user_permissions(
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
):
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()

View File

@@ -269,11 +269,19 @@ async def disconnect(sid):
# print(f"Unknown session ID {sid} disconnected")
def get_event_emitter(request_info):
def get_event_emitter(request_info, update_db=True):
async def __event_emitter__(event_data):
user_id = request_info["user_id"]
session_ids = list(
set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
set(
USER_POOL.get(user_id, [])
+ (
[request_info.get("session_id")]
if request_info.get("session_id")
else []
)
)
)
for session_id in session_ids:
@@ -287,40 +295,41 @@ def get_event_emitter(request_info):
to=session_id,
)
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
if update_db:
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
return __event_emitter__

View File

@@ -106,6 +106,7 @@ async def process_filter_functions(
# Handle file cleanup for inlet
if skip_files and "files" in form_data.get("metadata", {}):
del form_data["files"]
del form_data["metadata"]["files"]
return form_data, {}

View File

@@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
event_caller = extra_params["__event_call__"]
metadata = extra_params["__metadata__"]
task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
@@ -189,19 +192,33 @@ async def chat_completion_tools_handler(
tool_function_params = tool_call.get("parameters", {})
try:
required_params = (
tools[tool_function_name]
.get("spec", {})
.get("parameters", {})
.get("required", [])
tool = tools[tool_function_name]
spec = tool.get("spec", {})
allowed_params = (
spec.get("parameters", {}).get("properties", {}).keys()
)
tool_function = tools[tool_function_name]["callable"]
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_output = await tool_function(**tool_function_params)
if tool.get("direct", False):
tool_output = await tool_function(**tool_function_params)
else:
tool_output = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get("session_id", None),
},
}
)
except Exception as e:
tool_output = str(e)
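The parameter filter above now keeps every argument the tool spec declares under properties, rather than only the ones listed as required. A standalone sketch of the same idea, using an illustrative spec and argument dict rather than the handler's real variables:

# Illustrative only: keep every argument the spec declares, drop the rest.
spec = {
    "parameters": {
        "properties": {"query": {"type": "string"}, "limit": {"type": "integer"}},
        "required": ["query"],
    }
}
allowed_params = spec.get("parameters", {}).get("properties", {}).keys()

tool_function_params = {"query": "open webui", "limit": 5, "bogus": True}
tool_function_params = {
    k: v for k, v in tool_function_params.items() if k in allowed_params
}
# -> {"query": "open webui", "limit": 5}; the old filter would have kept only "query".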
@@ -767,12 +784,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_specs = form_data.get("tool_specs", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_specs=}")
tools_dict = {}
if tool_ids:
# If tool_ids field is present, then get the tools
tools = get_tools(
tools_dict = get_tools(
request,
tool_ids,
user,
@@ -783,20 +806,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
log.info(f"{tools_dict=}")
if tool_specs:
for tool in tool_specs:
callable = tool.pop("callable", None)
tools_dict[tool["name"]] = {
"direct": True,
"callable": callable,
"spec": tool,
}
if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
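For the client-side path, the request body is expected to carry a tool_specs list shaped like OpenAI function specs; each entry is registered as a direct tool alongside the server-side tools resolved from tool_ids. The concrete spec below is only an example of the shape the loop above consumes.

# Hypothetical client-side tool spec and the dict entry the loop builds from it.
tool_specs = [
    {
        "name": "get_weather",
        "description": "Return the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]

tools_dict = {}
for tool in tool_specs:
    fn = tool.pop("callable", None)  # absent here; popped so it never reaches the spec
    tools_dict[tool["name"]] = {"direct": True, "callable": fn, "spec": tool}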
@@ -815,7 +848,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
for source_idx, source in enumerate(sources):
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string += f"<source><source_id>{source_idx + 1}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
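Source tags in the RAG context string are now numbered from 1 rather than 0. A tiny sketch of the resulting string, with an illustrative sources list:

# Illustrative: two retrieved chunks become 1-based <source> blocks.
sources = [{"document": ["chunk A"]}, {"document": ["chunk B"]}]

context_string = ""
for source_idx, source in enumerate(sources):
    for doc_context in source.get("document", []):
        context_string += (
            f"<source><source_id>{source_idx + 1}</source_id>"
            f"<source_context>{doc_context}</source_context></source>\n"
        )
# -> <source_id>1</source_id> for "chunk A", <source_id>2</source_id> for "chunk B"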
@@ -1082,8 +1115,6 @@ async def process_chat_response(
for filter_id in get_sorted_filter_ids(model)
]
print(f"{filter_functions=}")
# Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
@@ -1563,7 +1594,9 @@ async def process_chat_response(
value = delta.get("content")
reasoning_content = delta.get("reasoning_content")
reasoning_content = delta.get(
"reasoning_content"
) or delta.get("reasoning")
if reasoning_content:
if (
not content_blocks
@@ -1766,18 +1799,36 @@ async def process_chat_response(
spec = tool.get("spec", {})
try:
required_params = spec.get("parameters", {}).get(
"required", []
allowed_params = (
spec.get("parameters", {})
.get("properties", {})
.keys()
)
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_result = await tool_function(
**tool_function_params
)
if tool.get("direct", False):
tool_result = await tool_function(
**tool_function_params
)
else:
tool_result = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get(
"session_id", None
),
},
}
)
except Exception as e:
tool_result = str(e)

View File

@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
"tags": model.get("tags", []),
}
for model in ollama_models["models"]
]

View File

@@ -94,7 +94,7 @@ class OAuthManager:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
oauth_roles = []
# Default/fallback role if no matching roles are found
role = auth_manager_config.DEFAULT_USER_ROLE
@@ -104,7 +104,7 @@ class OAuthManager:
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
oauth_roles = claim_data if isinstance(claim_data, list) else []
log.debug(f"Oauth Roles claim: {oauth_claim}")
log.debug(f"User roles from oauth: {oauth_roles}")
@@ -140,6 +140,7 @@ class OAuthManager:
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
user_oauth_groups = []
# Nested claim search for groups claim
if oauth_claim:
claim_data = user_data
@@ -160,7 +161,7 @@ class OAuthManager:
# Remove groups that user is no longer a part of
for group_model in user_current_groups:
if group_model.name not in user_oauth_groups:
if user_oauth_groups and group_model.name not in user_oauth_groups:
# Remove group from user
log.debug(
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
@@ -186,8 +187,10 @@ class OAuthManager:
# Add user to new groups
for group_model in all_available_groups:
if group_model.name in user_oauth_groups and not any(
gm.name == group_model.name for gm in user_current_groups
if (
user_oauth_groups
and group_model.name in user_oauth_groups
and not any(gm.name == group_model.name for gm in user_current_groups)
):
# Add user to group
log.debug(
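Role and group claims are resolved by walking a dot-separated claim path through the userinfo payload, and a missing or non-list value now falls back to an empty list instead of None. A compact sketch of that traversal; the claim path and payload below are made up:

# Hypothetical userinfo payload and a dotted claim path such as "realm_access.roles".
user_data = {"realm_access": {"roles": ["admin", "user"]}}
oauth_claim = "realm_access.roles"

claim_data = user_data
for nested_claim in oauth_claim.split("."):
    claim_data = claim_data.get(nested_claim, {})

oauth_roles = claim_data if isinstance(claim_data, list) else []
# -> ["admin", "user"]; an absent claim now yields [] rather than None.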

View File

@@ -110,6 +110,11 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
"num_thread": int,
}
# Extract keep_alive from options if it exists
if "options" in form_data and "keep_alive" in form_data["options"]:
form_data["keep_alive"] = form_data["options"]["keep_alive"]
del form_data["options"]["keep_alive"]
return apply_model_params_to_body(params, form_data, mappings)
@@ -231,6 +236,11 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
"system"
] # To prevent Ollama warning of invalid option provided
# Extract keep_alive from options if it exists
if "keep_alive" in ollama_options:
ollama_payload["keep_alive"] = ollama_options["keep_alive"]
del ollama_options["keep_alive"]
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
if "stop" in openai_payload:
ollama_options = ollama_payload.get("options", {})
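Both helpers now hoist keep_alive out of the options dict so it reaches Ollama as a top-level field instead of triggering an invalid-option warning. A small sketch of the transformation on an example payload:

# Example payload before the conversion.
form_data = {"model": "llama3", "options": {"temperature": 0.7, "keep_alive": "5m"}}

if "options" in form_data and "keep_alive" in form_data["options"]:
    form_data["keep_alive"] = form_data["options"].pop("keep_alive")

# -> {"model": "llama3", "options": {"temperature": 0.7}, "keep_alive": "5m"}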

View File

@@ -7,7 +7,7 @@ import types
import tempfile
import logging
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, PIP_OPTIONS, PIP_PACKAGE_INDEX_OPTIONS
from open_webui.models.functions import Functions
from open_webui.models.tools import Tools
@@ -165,15 +165,19 @@ def load_function_module_by_id(function_id, content=None):
os.unlink(temp_file.name)
def install_frontmatter_requirements(requirements):
def install_frontmatter_requirements(requirements: str):
if requirements:
try:
req_list = [req.strip() for req in requirements.split(",")]
for req in req_list:
log.info(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
log.info(f"Installing requirements: {' '.join(req_list)}")
subprocess.check_call(
[sys.executable, "-m", "pip", "install"]
+ PIP_OPTIONS
+ req_list
+ PIP_PACKAGE_INDEX_OPTIONS
)
except Exception as e:
log.error(f"Error installing package: {req}")
log.error(f"Error installing packages: {' '.join(req_list)}")
raise e
else:
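Frontmatter requirements are now installed with a single pip invocation, with extra flags and package-index options pulled from the environment. A sketch of the command that gets built; the PIP_OPTIONS and PIP_PACKAGE_INDEX_OPTIONS values below are examples, not defaults:

import sys

# Example: frontmatter `requirements: requests, beautifulsoup4`
requirements = "requests, beautifulsoup4"
PIP_OPTIONS = ["--no-cache-dir"]                                        # example value
PIP_PACKAGE_INDEX_OPTIONS = ["--index-url", "https://pypi.org/simple"]  # example value

req_list = [req.strip() for req in requirements.split(",")]
cmd = (
    [sys.executable, "-m", "pip", "install"]
    + PIP_OPTIONS
    + req_list
    + PIP_PACKAGE_INDEX_OPTIONS
)
# -> python -m pip install --no-cache-dir requests beautifulsoup4 --index-url https://pypi.org/simple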

View File

@@ -0,0 +1,26 @@
from opentelemetry.semconv.trace import SpanAttributes as _SpanAttributes
# Span Tags
SPAN_DB_TYPE = "mysql"
SPAN_REDIS_TYPE = "redis"
SPAN_DURATION = "duration"
SPAN_SQL_STR = "sql"
SPAN_SQL_EXPLAIN = "explain"
SPAN_ERROR_TYPE = "error"
class SpanAttributes(_SpanAttributes):
"""
Span Attributes
"""
DB_INSTANCE = "db.instance"
DB_TYPE = "db.type"
DB_IP = "db.ip"
DB_PORT = "db.port"
ERROR_KIND = "error.kind"
ERROR_OBJECT = "error.object"
ERROR_MESSAGE = "error.message"
RESULT_CODE = "result.code"
RESULT_MESSAGE = "result.message"
RESULT_ERRORS = "result.errors"

View File

@@ -0,0 +1,31 @@
import threading
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import BatchSpanProcessor
class LazyBatchSpanProcessor(BatchSpanProcessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.done = True
with self.condition:
self.condition.notify_all()
self.worker_thread.join()
self.done = False
self.worker_thread = None
def on_end(self, span: ReadableSpan) -> None:
if self.worker_thread is None:
self.worker_thread = threading.Thread(
name=self.__class__.__name__, target=self.worker, daemon=True
)
self.worker_thread.start()
super().on_end(span)
def shutdown(self) -> None:
self.done = True
with self.condition:
self.condition.notify_all()
if self.worker_thread:
self.worker_thread.join()
self.span_exporter.shutdown()
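The processor above tears down the worker thread that BatchSpanProcessor starts in its constructor and only spawns a new one when the first span ends, so importing the telemetry module does not leave an idle export thread in processes that never trace. A minimal usage sketch; the console exporter is just a stand-in for an OTLP exporter:

# Minimal sketch: wire the lazy processor to a console exporter for local testing.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter

provider = TracerProvider()
provider.add_span_processor(LazyBatchSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

tracer = trace.get_tracer("open-webui-example")
with tracer.start_as_current_span("demo"):  # export worker starts on first on_end
    pass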

View File

@@ -0,0 +1,202 @@
import logging
import traceback
from typing import Collection, Union
from aiohttp import (
TraceRequestStartParams,
TraceRequestEndParams,
TraceRequestExceptionParams,
)
from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
from fastapi import FastAPI
from opentelemetry.instrumentation.httpx import (
HTTPXClientInstrumentor,
RequestInfo,
ResponseInfo,
)
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.trace import Span, StatusCode
from redis import Redis
from requests import PreparedRequest, Response
from sqlalchemy import Engine
from fastapi import status
from open_webui.utils.telemetry.constants import SPAN_REDIS_TYPE, SpanAttributes
from open_webui.env import SRC_LOG_LEVELS
logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
def requests_hook(span: Span, request: PreparedRequest):
"""
Http Request Hook
"""
span.update_name(f"{request.method} {request.url}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: request.url,
SpanAttributes.HTTP_METHOD: request.method,
}
)
def response_hook(span: Span, request: PreparedRequest, response: Response):
"""
HTTP Response Hook
"""
span.set_attributes(
attributes={
SpanAttributes.HTTP_STATUS_CODE: response.status_code,
}
)
span.set_status(StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK)
def redis_request_hook(span: Span, instance: Redis, args, kwargs):
"""
Redis Request Hook
"""
try:
connection_kwargs: dict = instance.connection_pool.connection_kwargs
host = connection_kwargs.get("host")
port = connection_kwargs.get("port")
db = connection_kwargs.get("db")
span.set_attributes(
{
SpanAttributes.DB_INSTANCE: f"{host}/{db}",
SpanAttributes.DB_NAME: f"{host}/{db}",
SpanAttributes.DB_TYPE: SPAN_REDIS_TYPE,
SpanAttributes.DB_PORT: port,
SpanAttributes.DB_IP: host,
SpanAttributes.DB_STATEMENT: " ".join([str(i) for i in args]),
SpanAttributes.DB_OPERATION: str(args[0]),
}
)
except Exception: # pylint: disable=W0718
logger.error(traceback.format_exc())
def httpx_request_hook(span: Span, request: RequestInfo):
"""
HTTPX Request Hook
"""
span.update_name(f"{request.method.decode()} {str(request.url)}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: str(request.url),
SpanAttributes.HTTP_METHOD: request.method.decode(),
}
)
def httpx_response_hook(span: Span, request: RequestInfo, response: ResponseInfo):
"""
HTTPX Response Hook
"""
span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.status_code)
span.set_status(
StatusCode.ERROR
if response.status_code >= status.HTTP_400_BAD_REQUEST
else StatusCode.OK
)
async def httpx_async_request_hook(span: Span, request: RequestInfo):
"""
Async Request Hook
"""
httpx_request_hook(span, request)
async def httpx_async_response_hook(
span: Span, request: RequestInfo, response: ResponseInfo
):
"""
Async Response Hook
"""
httpx_response_hook(span, request, response)
def aiohttp_request_hook(span: Span, request: TraceRequestStartParams):
"""
Aiohttp Request Hook
"""
span.update_name(f"{request.method} {str(request.url)}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: str(request.url),
SpanAttributes.HTTP_METHOD: request.method,
}
)
def aiohttp_response_hook(
span: Span, response: Union[TraceRequestExceptionParams, TraceRequestEndParams]
):
"""
Aiohttp Response Hook
"""
if isinstance(response, TraceRequestEndParams):
span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.response.status)
span.set_status(
StatusCode.ERROR
if response.response.status >= status.HTTP_400_BAD_REQUEST
else StatusCode.OK
)
elif isinstance(response, TraceRequestExceptionParams):
span.set_status(StatusCode.ERROR)
span.set_attribute(SpanAttributes.ERROR_MESSAGE, str(response.exception))
class Instrumentor(BaseInstrumentor):
"""
Instrument OT
"""
def __init__(self, app: FastAPI, db_engine: Engine):
self.app = app
self.db_engine = db_engine
def instrumentation_dependencies(self) -> Collection[str]:
return []
def _instrument(self, **kwargs):
instrument_fastapi(app=self.app)
SQLAlchemyInstrumentor().instrument(engine=self.db_engine)
RedisInstrumentor().instrument(request_hook=redis_request_hook)
RequestsInstrumentor().instrument(
request_hook=requests_hook, response_hook=response_hook
)
LoggingInstrumentor().instrument()
HTTPXClientInstrumentor().instrument(
request_hook=httpx_request_hook,
response_hook=httpx_response_hook,
async_request_hook=httpx_async_request_hook,
async_response_hook=httpx_async_response_hook,
)
AioHttpClientInstrumentor().instrument(
request_hook=aiohttp_request_hook,
response_hook=aiohttp_response_hook,
)
def _uninstrument(self, **kwargs):
if getattr(self, "instrumentors", None) is None:
return
for instrumentor in self.instrumentors:
instrumentor.uninstrument()

View File

@@ -0,0 +1,23 @@
from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from sqlalchemy import Engine
from open_webui.utils.telemetry.exporters import LazyBatchSpanProcessor
from open_webui.utils.telemetry.instrumentors import Instrumentor
from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT
def setup(app: FastAPI, db_engine: Engine):
# set up trace
trace.set_tracer_provider(
TracerProvider(
resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
)
)
# otlp export
exporter = OTLPSpanExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT)
trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
Instrumentor(app=app, db_engine=db_engine).instrument()
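A hedged sketch of how this helper might be wired in at application startup, assuming the new module is importable as open_webui.utils.telemetry.setup and guarded by a hypothetical enable flag; the FastAPI app and engine below are illustrative stand-ins for the ones the application already owns:

import os

from fastapi import FastAPI
from sqlalchemy import create_engine

from open_webui.utils.telemetry.setup import setup as setup_telemetry  # assumed path

app = FastAPI()
engine = create_engine("sqlite:///./webui.db")  # example DSN

if os.environ.get("ENABLE_OTEL", "false").lower() == "true":  # hypothetical flag
    setup_telemetry(app=app, db_engine=engine)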

View File

@@ -1,6 +1,9 @@
import inspect
import logging
import re
import inspect
import uuid
from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial

View File

@@ -37,13 +37,13 @@ asgiref==3.8.1
# AI libraries
openai
anthropic
google-generativeai==0.7.2
google-generativeai==0.8.4
tiktoken
langchain==0.3.19
langchain-community==0.3.18
fake-useragent==1.5.1
fake-useragent==2.1.0
chromadb==0.6.2
pymilvus==2.5.0
qdrant-client~=1.12.0
@@ -78,6 +78,7 @@ sentencepiece
soundfile==0.13.1
azure-ai-documentintelligence==1.0.0
pillow==11.1.0
opencv-python-headless==4.11.0.86
rapidocr-onnxruntime==1.3.24
rank-bm25==0.2.2
@@ -118,3 +119,16 @@ ldap3==2.9.1
## Firecrawl
firecrawl-py==1.12.0
## Trace
opentelemetry-api==1.30.0
opentelemetry-sdk==1.30.0
opentelemetry-exporter-otlp==1.30.0
opentelemetry-instrumentation==0.51b0
opentelemetry-instrumentation-fastapi==0.51b0
opentelemetry-instrumentation-sqlalchemy==0.51b0
opentelemetry-instrumentation-redis==0.51b0
opentelemetry-instrumentation-requests==0.51b0
opentelemetry-instrumentation-logging==0.51b0
opentelemetry-instrumentation-httpx==0.51b0
opentelemetry-instrumentation-aiohttp-client==0.51b0

View File

@@ -41,4 +41,5 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn open_webui.main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'
uvicorn open_webui.main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*' --ws auto
:: For SSL, use: uvicorn open_webui.main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*' --ssl-keyfile "key.pem" --ssl-certfile "cert.pem" --ws auto