mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 03:47:49 +01:00
Merge branch 'dev' into k_reranker
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import base64
|
||||
import redis
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -17,6 +18,7 @@ from open_webui.env import (
|
||||
DATA_DIR,
|
||||
DATABASE_URL,
|
||||
ENV,
|
||||
REDIS_URL,
|
||||
FRONTEND_BUILD_DIR,
|
||||
OFFLINE_MODE,
|
||||
OPEN_WEBUI_DIR,
|
||||
@@ -248,9 +250,14 @@ class PersistentConfig(Generic[T]):
|
||||
|
||||
class AppConfig:
|
||||
_state: dict[str, PersistentConfig]
|
||||
_redis: Optional[redis.Redis] = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, redis_url: Optional[str] = None):
|
||||
super().__setattr__("_state", {})
|
||||
if redis_url:
|
||||
super().__setattr__(
|
||||
"_redis", redis.Redis.from_url(redis_url, decode_responses=True)
|
||||
)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if isinstance(value, PersistentConfig):
|
||||
@@ -259,7 +266,31 @@ class AppConfig:
|
||||
self._state[key].value = value
|
||||
self._state[key].save()
|
||||
|
||||
if self._redis:
|
||||
redis_key = f"open-webui:config:{key}"
|
||||
self._redis.set(redis_key, json.dumps(self._state[key].value))
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self._state:
|
||||
raise AttributeError(f"Config key '{key}' not found")
|
||||
|
||||
# If Redis is available, check for an updated value
|
||||
if self._redis:
|
||||
redis_key = f"open-webui:config:{key}"
|
||||
redis_value = self._redis.get(redis_key)
|
||||
|
||||
if redis_value is not None:
|
||||
try:
|
||||
decoded_value = json.loads(redis_value)
|
||||
|
||||
# Update the in-memory value if different
|
||||
if self._state[key].value != decoded_value:
|
||||
self._state[key].value = decoded_value
|
||||
log.info(f"Updated {key} from Redis: {decoded_value}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
log.error(f"Invalid JSON format in Redis for {key}: {redis_value}")
|
||||
|
||||
return self._state[key].value
|
||||
|
||||
|
||||
@@ -1276,7 +1307,7 @@ Strictly return in JSON format:
|
||||
ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig(
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION",
|
||||
"task.autocomplete.enable",
|
||||
os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true",
|
||||
os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "False").lower() == "true",
|
||||
)
|
||||
|
||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig(
|
||||
@@ -1548,8 +1579,10 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
|
||||
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True)
|
||||
OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
|
||||
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
|
||||
OPENSEARCH_CERT_VERIFY = (
|
||||
os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
|
||||
)
|
||||
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
|
||||
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
|
||||
|
||||
@@ -1623,6 +1656,12 @@ TIKA_SERVER_URL = PersistentConfig(
|
||||
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
|
||||
)
|
||||
|
||||
DOCLING_SERVER_URL = PersistentConfig(
|
||||
"DOCLING_SERVER_URL",
|
||||
"rag.docling_server_url",
|
||||
os.getenv("DOCLING_SERVER_URL", "http://docling:5001"),
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT",
|
||||
"rag.document_intelligence_endpoint",
|
||||
@@ -1955,6 +1994,12 @@ TAVILY_API_KEY = PersistentConfig(
|
||||
os.getenv("TAVILY_API_KEY", ""),
|
||||
)
|
||||
|
||||
TAVILY_EXTRACT_DEPTH = PersistentConfig(
|
||||
"TAVILY_EXTRACT_DEPTH",
|
||||
"rag.web.search.tavily_extract_depth",
|
||||
os.getenv("TAVILY_EXTRACT_DEPTH", "basic"),
|
||||
)
|
||||
|
||||
JINA_API_KEY = PersistentConfig(
|
||||
"JINA_API_KEY",
|
||||
"rag.web.search.jina_api_key",
|
||||
@@ -2041,6 +2086,12 @@ PLAYWRIGHT_WS_URI = PersistentConfig(
|
||||
os.environ.get("PLAYWRIGHT_WS_URI", None),
|
||||
)
|
||||
|
||||
PLAYWRIGHT_TIMEOUT = PersistentConfig(
|
||||
"PLAYWRIGHT_TIMEOUT",
|
||||
"rag.web.loader.engine.playwright.timeout",
|
||||
int(os.environ.get("PLAYWRIGHT_TIMEOUT", "10")),
|
||||
)
|
||||
|
||||
FIRECRAWL_API_KEY = PersistentConfig(
|
||||
"FIRECRAWL_API_KEY",
|
||||
"firecrawl.api_key",
|
||||
|
||||
@@ -105,7 +105,6 @@ for source in log_sources:
|
||||
|
||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||
|
||||
|
||||
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||
if WEBUI_NAME != "Open WebUI":
|
||||
WEBUI_NAME += " (Open WebUI)"
|
||||
@@ -130,7 +129,6 @@ else:
|
||||
except Exception:
|
||||
PACKAGE_DATA = {"version": "0.0.0"}
|
||||
|
||||
|
||||
VERSION = PACKAGE_DATA["version"]
|
||||
|
||||
|
||||
@@ -161,7 +159,6 @@ try:
|
||||
except Exception:
|
||||
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
|
||||
|
||||
|
||||
# Convert markdown content to HTML
|
||||
html_content = markdown.markdown(changelog_content)
|
||||
|
||||
@@ -192,7 +189,6 @@ for version in soup.find_all("h2"):
|
||||
|
||||
changelog_json[version_number] = version_data
|
||||
|
||||
|
||||
CHANGELOG = changelog_json
|
||||
|
||||
####################################
|
||||
@@ -209,7 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
|
||||
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# WEBUI_BUILD_HASH
|
||||
####################################
|
||||
@@ -244,7 +239,6 @@ if FROM_INIT_PY:
|
||||
|
||||
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
|
||||
|
||||
|
||||
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
|
||||
|
||||
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
|
||||
@@ -256,7 +250,6 @@ if FROM_INIT_PY:
|
||||
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
|
||||
).resolve()
|
||||
|
||||
|
||||
####################################
|
||||
# Database
|
||||
####################################
|
||||
@@ -321,7 +314,6 @@ RESET_CONFIG_ON_START = (
|
||||
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
ENABLE_REALTIME_CHAT_SAVE = (
|
||||
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
|
||||
)
|
||||
@@ -330,7 +322,7 @@ ENABLE_REALTIME_CHAT_SAVE = (
|
||||
# REDIS
|
||||
####################################
|
||||
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "")
|
||||
|
||||
####################################
|
||||
# WEBUI_AUTH (Required for security)
|
||||
@@ -399,18 +391,16 @@ else:
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
|
||||
"AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
|
||||
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""),
|
||||
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"),
|
||||
)
|
||||
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
|
||||
else:
|
||||
try:
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
except Exception:
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 5
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10
|
||||
|
||||
####################################
|
||||
# OFFLINE_MODE
|
||||
@@ -424,13 +414,12 @@ if OFFLINE_MODE:
|
||||
####################################
|
||||
# AUDIT LOGGING
|
||||
####################################
|
||||
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
|
||||
# Where to store log file
|
||||
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
|
||||
# Maximum size of a file before rotating into a new log file
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
|
||||
# METADATA | REQUEST | REQUEST_RESPONSE
|
||||
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
|
||||
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
|
||||
try:
|
||||
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
|
||||
except ValueError:
|
||||
@@ -442,3 +431,26 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
|
||||
)
|
||||
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
|
||||
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
|
||||
|
||||
####################################
|
||||
# OPENTELEMETRY
|
||||
####################################
|
||||
|
||||
ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
|
||||
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
|
||||
)
|
||||
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
|
||||
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
|
||||
"OTEL_RESOURCE_ATTRIBUTES", ""
|
||||
) # e.g. key1=val1,key2=val2
|
||||
OTEL_TRACES_SAMPLER = os.environ.get(
|
||||
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
|
||||
).lower()
|
||||
|
||||
####################################
|
||||
# TOOLS/FUNCTIONS PIP OPTIONS
|
||||
####################################
|
||||
|
||||
PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
|
||||
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
|
||||
|
||||
@@ -223,6 +223,9 @@ async def generate_function_chat_completion(
|
||||
extra_params = {
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__chat_id__": metadata.get("chat_id", None),
|
||||
"__session_id__": metadata.get("session_id", None),
|
||||
"__message_id__": metadata.get("message_id", None),
|
||||
"__task__": __task__,
|
||||
"__task_body__": __task_body__,
|
||||
"__files__": files,
|
||||
|
||||
@@ -84,7 +84,7 @@ from open_webui.routers.retrieval import (
|
||||
get_rf,
|
||||
)
|
||||
|
||||
from open_webui.internal.db import Session
|
||||
from open_webui.internal.db import Session, engine
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
@@ -155,6 +155,7 @@ from open_webui.config import (
|
||||
AUDIO_TTS_AZURE_SPEECH_REGION,
|
||||
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
PLAYWRIGHT_TIMEOUT,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
@@ -186,6 +187,7 @@ from open_webui.config import (
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL,
|
||||
DOCLING_SERVER_URL,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY,
|
||||
RAG_TOP_K,
|
||||
@@ -213,6 +215,7 @@ from open_webui.config import (
|
||||
SERPSTACK_API_KEY,
|
||||
SERPSTACK_HTTPS,
|
||||
TAVILY_API_KEY,
|
||||
TAVILY_EXTRACT_DEPTH,
|
||||
BING_SEARCH_V7_ENDPOINT,
|
||||
BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
BRAVE_SEARCH_API_KEY,
|
||||
@@ -313,6 +316,7 @@ from open_webui.env import (
|
||||
AUDIT_EXCLUDED_PATHS,
|
||||
AUDIT_LOG_LEVEL,
|
||||
CHANGELOG,
|
||||
REDIS_URL,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
MAX_BODY_LOG_SIZE,
|
||||
SAFE_MODE,
|
||||
@@ -328,6 +332,7 @@ from open_webui.env import (
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
RESET_CONFIG_ON_START,
|
||||
OFFLINE_MODE,
|
||||
ENABLE_OTEL,
|
||||
)
|
||||
|
||||
|
||||
@@ -355,7 +360,6 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
|
||||
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
Functions.deactivate_all_functions()
|
||||
@@ -419,11 +423,24 @@ app = FastAPI(
|
||||
|
||||
oauth_manager = OAuthManager(app)
|
||||
|
||||
app.state.config = AppConfig()
|
||||
app.state.config = AppConfig(redis_url=REDIS_URL)
|
||||
|
||||
app.state.WEBUI_NAME = WEBUI_NAME
|
||||
app.state.LICENSE_METADATA = None
|
||||
|
||||
|
||||
########################################
|
||||
#
|
||||
# OPENTELEMETRY
|
||||
#
|
||||
########################################
|
||||
|
||||
if ENABLE_OTEL:
|
||||
from open_webui.utils.telemetry.setup import setup as setup_opentelemetry
|
||||
|
||||
setup_opentelemetry(app=app, db_engine=engine)
|
||||
|
||||
|
||||
########################################
|
||||
#
|
||||
# OLLAMA
|
||||
@@ -551,6 +568,7 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
||||
|
||||
@@ -614,8 +632,10 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_
|
||||
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
|
||||
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
|
||||
app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
|
||||
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
|
||||
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
|
||||
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = None
|
||||
app.state.ef = None
|
||||
@@ -949,14 +969,24 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
return filtered_models
|
||||
|
||||
models = await get_all_models(request, user=user)
|
||||
all_models = await get_all_models(request, user=user)
|
||||
|
||||
# Filter out filter pipelines
|
||||
models = [
|
||||
model
|
||||
for model in models
|
||||
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
|
||||
]
|
||||
models = []
|
||||
for model in all_models:
|
||||
# Filter out filter pipelines
|
||||
if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
|
||||
continue
|
||||
|
||||
model_tags = [
|
||||
tag.get("name")
|
||||
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
|
||||
]
|
||||
tags = [tag.get("name") for tag in model.get("tags", [])]
|
||||
|
||||
tags = list(set(model_tags + tags))
|
||||
model["tags"] = [{"name": tag} for tag in tags]
|
||||
|
||||
models.append(model)
|
||||
|
||||
model_order_list = request.app.state.config.MODEL_ORDER_LIST
|
||||
if model_order_list:
|
||||
|
||||
@@ -105,7 +105,7 @@ class TikaLoader:
|
||||
|
||||
if r.ok:
|
||||
raw_metadata = r.json()
|
||||
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
|
||||
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
|
||||
|
||||
if "Content-Type" in raw_metadata:
|
||||
headers["Content-Type"] = raw_metadata["Content-Type"]
|
||||
@@ -117,6 +117,52 @@ class TikaLoader:
|
||||
raise Exception(f"Error calling Tika: {r.reason}")
|
||||
|
||||
|
||||
class DoclingLoader:
|
||||
def __init__(self, url, file_path=None, mime_type=None):
|
||||
self.url = url.rstrip("/")
|
||||
self.file_path = file_path
|
||||
self.mime_type = mime_type
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {
|
||||
"files": (
|
||||
self.file_path,
|
||||
f,
|
||||
self.mime_type or "application/octet-stream",
|
||||
)
|
||||
}
|
||||
|
||||
params = {
|
||||
"image_export_mode": "placeholder",
|
||||
"table_mode": "accurate",
|
||||
}
|
||||
|
||||
endpoint = f"{self.url}/v1alpha/convert/file"
|
||||
r = requests.post(endpoint, files=files, data=params)
|
||||
|
||||
if r.ok:
|
||||
result = r.json()
|
||||
document_data = result.get("document", {})
|
||||
text = document_data.get("md_content", "<No text content found>")
|
||||
|
||||
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
|
||||
|
||||
log.debug("Docling extracted text: %s", text)
|
||||
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
else:
|
||||
error_msg = f"Error calling Docling API: {r.reason}"
|
||||
if r.text:
|
||||
try:
|
||||
error_data = r.json()
|
||||
if "detail" in error_data:
|
||||
error_msg += f" - {error_data['detail']}"
|
||||
except Exception:
|
||||
error_msg += f" - {r.text}"
|
||||
raise Exception(f"Error calling Docling: {error_msg}")
|
||||
|
||||
|
||||
class Loader:
|
||||
def __init__(self, engine: str = "", **kwargs):
|
||||
self.engine = engine
|
||||
@@ -149,6 +195,12 @@ class Loader:
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
)
|
||||
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
|
||||
loader = DoclingLoader(
|
||||
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
)
|
||||
elif (
|
||||
self.engine == "document_intelligence"
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
|
||||
|
||||
93
backend/open_webui/retrieval/loaders/tavily.py
Normal file
93
backend/open_webui/retrieval/loaders/tavily.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import requests
|
||||
import logging
|
||||
from typing import Iterator, List, Literal, Union
|
||||
|
||||
from langchain_core.document_loaders import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class TavilyLoader(BaseLoader):
|
||||
"""Extract web page content from URLs using Tavily Extract API.
|
||||
|
||||
This is a LangChain document loader that uses Tavily's Extract API to
|
||||
retrieve content from web pages and return it as Document objects.
|
||||
|
||||
Args:
|
||||
urls: URL or list of URLs to extract content from.
|
||||
api_key: The Tavily API key.
|
||||
extract_depth: Depth of extraction, either "basic" or "advanced".
|
||||
continue_on_failure: Whether to continue if extraction of a URL fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
urls: Union[str, List[str]],
|
||||
api_key: str,
|
||||
extract_depth: Literal["basic", "advanced"] = "basic",
|
||||
continue_on_failure: bool = True,
|
||||
) -> None:
|
||||
"""Initialize Tavily Extract client.
|
||||
|
||||
Args:
|
||||
urls: URL or list of URLs to extract content from.
|
||||
api_key: The Tavily API key.
|
||||
include_images: Whether to include images in the extraction.
|
||||
extract_depth: Depth of extraction, either "basic" or "advanced".
|
||||
advanced extraction retrieves more data, including tables and
|
||||
embedded content, with higher success but may increase latency.
|
||||
basic costs 1 credit per 5 successful URL extractions,
|
||||
advanced costs 2 credits per 5 successful URL extractions.
|
||||
continue_on_failure: Whether to continue if extraction of a URL fails.
|
||||
"""
|
||||
if not urls:
|
||||
raise ValueError("At least one URL must be provided.")
|
||||
|
||||
self.api_key = api_key
|
||||
self.urls = urls if isinstance(urls, list) else [urls]
|
||||
self.extract_depth = extract_depth
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.api_url = "https://api.tavily.com/extract"
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Extract and yield documents from the URLs using Tavily Extract API."""
|
||||
batch_size = 20
|
||||
for i in range(0, len(self.urls), batch_size):
|
||||
batch_urls = self.urls[i : i + batch_size]
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
# Use string for single URL, array for multiple URLs
|
||||
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
|
||||
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
|
||||
# Make the API call
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
# Process successful results
|
||||
for result in response_data.get("results", []):
|
||||
url = result.get("url", "")
|
||||
content = result.get("raw_content", "")
|
||||
if not content:
|
||||
log.warning(f"No content extracted from {url}")
|
||||
continue
|
||||
# Add URLs as metadata
|
||||
metadata = {"source": url}
|
||||
yield Document(
|
||||
page_content=content,
|
||||
metadata=metadata,
|
||||
)
|
||||
for failed in response_data.get("failed_results", []):
|
||||
url = failed.get("url", "")
|
||||
error = failed.get("error", "Unknown error")
|
||||
log.error(f"Failed to extract content from {url}: {error}")
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.error(f"Error extracting content from batch {batch_urls}: {e}")
|
||||
else:
|
||||
raise e
|
||||
@@ -189,8 +189,7 @@ def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
) -> dict:
|
||||
# Initialize lists to store combined data
|
||||
combined = []
|
||||
seen_hashes = set() # To store unique document hashes
|
||||
combined = dict() # To store documents with unique document hashes
|
||||
|
||||
for data in query_results:
|
||||
distances = data["distances"][0]
|
||||
@@ -203,10 +202,19 @@ def merge_and_sort_query_results(
|
||||
document.encode()
|
||||
).hexdigest() # Compute a hash for uniqueness
|
||||
|
||||
if doc_hash not in seen_hashes:
|
||||
seen_hashes.add(doc_hash)
|
||||
combined.append((distance, document, metadata))
|
||||
if doc_hash not in combined.keys():
|
||||
combined[doc_hash] = (distance, document, metadata)
|
||||
continue # if doc is new, no further comparison is needed
|
||||
|
||||
# if doc is alredy in, but new distance is better, update
|
||||
if not reverse and distance < combined[doc_hash][0]:
|
||||
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
|
||||
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
|
||||
combined[doc_hash] = (distance, document, metadata)
|
||||
if reverse and distance > combined[doc_hash][0]:
|
||||
combined[doc_hash] = (distance, document, metadata)
|
||||
|
||||
combined = list(combined.values())
|
||||
# Sort the list based on distances
|
||||
combined.sort(key=lambda x: x[0], reverse=reverse)
|
||||
|
||||
@@ -215,6 +223,12 @@ def merge_and_sort_query_results(
|
||||
zip(*combined[:k]) if combined else ([], [], [])
|
||||
)
|
||||
|
||||
# if chromaDB, the distance is 0 (best) to 2 (worse)
|
||||
# re-order to -1 (worst) to 1 (best) for relevance score
|
||||
if not reverse:
|
||||
sorted_distances = tuple(-dist for dist in sorted_distances)
|
||||
sorted_distances = tuple(dist + 1 for dist in sorted_distances)
|
||||
|
||||
# Create and return the output dictionary
|
||||
return {
|
||||
"distances": [list(sorted_distances)],
|
||||
@@ -306,13 +320,8 @@ def query_collection_with_hybrid_search(
|
||||
raise Exception(
|
||||
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if VECTOR_DB == "chroma":
|
||||
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
|
||||
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
|
||||
return merge_and_sort_query_results(results, k=k_reranker, reverse=False)
|
||||
else:
|
||||
return merge_and_sort_query_results(results, k=k_reranker, reverse=True)
|
||||
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
|
||||
@@ -166,12 +166,19 @@ class ChromaClient:
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
if ids:
|
||||
collection.delete(ids=ids)
|
||||
elif filter:
|
||||
collection.delete(where=filter)
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
if ids:
|
||||
collection.delete(ids=ids)
|
||||
elif filter:
|
||||
collection.delete(where=filter)
|
||||
except Exception as e:
|
||||
# If collection doesn't exist, that's fine - nothing to delete
|
||||
log.debug(
|
||||
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
|
||||
)
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy.helpers import bulk
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
@@ -21,7 +22,13 @@ class OpenSearchClient:
|
||||
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
|
||||
)
|
||||
|
||||
def _get_index_name(self, collection_name: str) -> str:
|
||||
return f"{self.index_prefix}_{collection_name}"
|
||||
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
if not result["hits"]["hits"]:
|
||||
return None
|
||||
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
@@ -31,9 +38,12 @@ class OpenSearchClient:
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
def _result_to_search_result(self, result) -> SearchResult:
|
||||
if not result["hits"]["hits"]:
|
||||
return None
|
||||
|
||||
ids = []
|
||||
distances = []
|
||||
documents = []
|
||||
@@ -46,34 +56,40 @@ class OpenSearchClient:
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
ids=[ids],
|
||||
distances=[distances],
|
||||
documents=[documents],
|
||||
metadatas=[metadatas],
|
||||
)
|
||||
|
||||
def _create_index(self, collection_name: str, dimension: int):
|
||||
body = {
|
||||
"settings": {"index": {"knn": True}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"id": {"type": "keyword"},
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dimension, # Adjust based on your vector dimensions
|
||||
"index": true,
|
||||
"type": "knn_vector",
|
||||
"dimension": dimension, # Adjust based on your vector dimensions
|
||||
"index": True,
|
||||
"similarity": "faiss",
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "ip", # Use inner product to approximate cosine similarity
|
||||
"space_type": "innerproduct", # Use inner product to approximate cosine similarity
|
||||
"engine": "faiss",
|
||||
"ef_construction": 128,
|
||||
"m": 16,
|
||||
"parameters": {
|
||||
"ef_construction": 128,
|
||||
"m": 16,
|
||||
},
|
||||
},
|
||||
},
|
||||
"text": {"type": "text"},
|
||||
"metadata": {"type": "object"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
self.client.indices.create(
|
||||
index=f"{self.index_prefix}_{collection_name}", body=body
|
||||
index=self._get_index_name(collection_name), body=body
|
||||
)
|
||||
|
||||
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
||||
@@ -83,39 +99,45 @@ class OpenSearchClient:
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# has_collection here means has index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
return self.client.indices.exists(
|
||||
index=f"{self.index_prefix}_{collection_name}"
|
||||
)
|
||||
return self.client.indices.exists(index=self._get_index_name(collection_name))
|
||||
|
||||
def delete_colleciton(self, collection_name: str):
|
||||
def delete_collection(self, collection_name: str):
|
||||
# delete_collection here means delete index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}")
|
||||
self.client.indices.delete(index=self._get_index_name(collection_name))
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float]], limit: int
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
query = {
|
||||
"size": limit,
|
||||
"_source": ["text", "metadata"],
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||
"params": {
|
||||
"vector": vectors[0]
|
||||
}, # Assuming single query vector
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{collection_name}", body=query
|
||||
)
|
||||
query = {
|
||||
"size": limit,
|
||||
"_source": ["text", "metadata"],
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
|
||||
"params": {
|
||||
"field": "vector",
|
||||
"query_value": vectors[0],
|
||||
}, # Assuming single query vector
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
result = self.client.search(
|
||||
index=self._get_index_name(collection_name), body=query
|
||||
)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
@@ -129,13 +151,15 @@ class OpenSearchClient:
|
||||
}
|
||||
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"match": {"metadata." + str(field): value}}
|
||||
)
|
||||
|
||||
size = limit if limit else 10
|
||||
|
||||
try:
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{collection_name}",
|
||||
index=self._get_index_name(collection_name),
|
||||
body=query_body,
|
||||
size=size,
|
||||
)
|
||||
@@ -146,14 +170,14 @@ class OpenSearchClient:
|
||||
return None
|
||||
|
||||
def _create_index_if_not_exists(self, collection_name: str, dimension: int):
|
||||
if not self.has_index(collection_name):
|
||||
if not self.has_collection(collection_name):
|
||||
self._create_index(collection_name, dimension)
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
|
||||
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{collection_name}", body=query
|
||||
index=self._get_index_name(collection_name), body=query
|
||||
)
|
||||
return self._result_to_get_result(result)
|
||||
|
||||
@@ -165,18 +189,18 @@ class OpenSearchClient:
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"index": {
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
}
|
||||
"_op_type": "index",
|
||||
"_index": self._get_index_name(collection_name),
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
self.client.bulk(actions)
|
||||
bulk(self.client, actions)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
self._create_index_if_not_exists(
|
||||
@@ -186,26 +210,47 @@ class OpenSearchClient:
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"index": {
|
||||
"_id": item["id"],
|
||||
"_index": f"{self.index_prefix}_{collection_name}",
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
}
|
||||
"_op_type": "update",
|
||||
"_index": self._get_index_name(collection_name),
|
||||
"_id": item["id"],
|
||||
"doc": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
"doc_as_upsert": True,
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
self.client.bulk(actions)
|
||||
bulk(self.client, actions)
|
||||
|
||||
def delete(self, collection_name: str, ids: list[str]):
|
||||
actions = [
|
||||
{"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}}
|
||||
for id in ids
|
||||
]
|
||||
self.client.bulk(body=actions)
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
if ids:
|
||||
actions = [
|
||||
{
|
||||
"_op_type": "delete",
|
||||
"_index": self._get_index_name(collection_name),
|
||||
"_id": id,
|
||||
}
|
||||
for id in ids
|
||||
]
|
||||
bulk(self.client, actions)
|
||||
elif filter:
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
}
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"match": {"metadata." + str(field): value}}
|
||||
)
|
||||
self.client.delete_by_query(
|
||||
index=self._get_index_name(collection_name), body=query_body
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
|
||||
|
||||
@@ -24,13 +24,17 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoa
|
||||
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.retrieval.loaders.tavily import TavilyLoader
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import (
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
PLAYWRIGHT_TIMEOUT,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
TAVILY_API_KEY,
|
||||
TAVILY_EXTRACT_DEPTH,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
@@ -113,7 +117,47 @@ def verify_ssl_cert(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class SafeFireCrawlLoader(BaseLoader):
|
||||
class RateLimitMixin:
|
||||
async def _wait_for_rate_limit(self):
|
||||
"""Wait to respect the rate limit if specified."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
await asyncio.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
def _sync_wait_for_rate_limit(self):
|
||||
"""Synchronous version of rate limit wait."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
time.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
|
||||
class URLProcessingMixin:
|
||||
def _verify_ssl_cert(self, url: str) -> bool:
|
||||
"""Verify SSL certificate for a URL."""
|
||||
return verify_ssl_cert(url)
|
||||
|
||||
async def _safe_process_url(self, url: str) -> bool:
|
||||
"""Perform safety checks before processing a URL."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
await self._wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
|
||||
"""Synchronous version of safety checks."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
self._sync_wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
|
||||
class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
def __init__(
|
||||
self,
|
||||
web_paths,
|
||||
@@ -184,7 +228,7 @@ class SafeFireCrawlLoader(BaseLoader):
|
||||
yield from loader.lazy_load()
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
log.exception(f"Error loading {url}: {e}")
|
||||
continue
|
||||
raise e
|
||||
|
||||
@@ -204,47 +248,124 @@ class SafeFireCrawlLoader(BaseLoader):
|
||||
yield document
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
log.exception(f"Error loading {url}: {e}")
|
||||
continue
|
||||
raise e
|
||||
|
||||
def _verify_ssl_cert(self, url: str) -> bool:
|
||||
return verify_ssl_cert(url)
|
||||
|
||||
async def _wait_for_rate_limit(self):
|
||||
"""Wait to respect the rate limit if specified."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
await asyncio.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
def __init__(
|
||||
self,
|
||||
web_paths: Union[str, List[str]],
|
||||
api_key: str,
|
||||
extract_depth: Literal["basic", "advanced"] = "basic",
|
||||
continue_on_failure: bool = True,
|
||||
requests_per_second: Optional[float] = None,
|
||||
verify_ssl: bool = True,
|
||||
trust_env: bool = False,
|
||||
proxy: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Initialize SafeTavilyLoader with rate limiting and SSL verification support.
|
||||
|
||||
def _sync_wait_for_rate_limit(self):
|
||||
"""Synchronous version of rate limit wait."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
time.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
Args:
|
||||
web_paths: List of URLs/paths to process.
|
||||
api_key: The Tavily API key.
|
||||
extract_depth: Depth of extraction ("basic" or "advanced").
|
||||
continue_on_failure: Whether to continue if extraction of a URL fails.
|
||||
requests_per_second: Number of requests per second to limit to.
|
||||
verify_ssl: If True, verify SSL certificates.
|
||||
trust_env: If True, use proxy settings from environment variables.
|
||||
proxy: Optional proxy configuration.
|
||||
"""
|
||||
# Initialize proxy configuration if using environment variables
|
||||
proxy_server = proxy.get("server") if proxy else None
|
||||
if trust_env and not proxy_server:
|
||||
env_proxies = urllib.request.getproxies()
|
||||
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
|
||||
if env_proxy_server:
|
||||
if proxy:
|
||||
proxy["server"] = env_proxy_server
|
||||
else:
|
||||
proxy = {"server": env_proxy_server}
|
||||
|
||||
async def _safe_process_url(self, url: str) -> bool:
|
||||
"""Perform safety checks before processing a URL."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
await self._wait_for_rate_limit()
|
||||
return True
|
||||
# Store parameters for creating TavilyLoader instances
|
||||
self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
|
||||
self.api_key = api_key
|
||||
self.extract_depth = extract_depth
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.verify_ssl = verify_ssl
|
||||
self.trust_env = trust_env
|
||||
self.proxy = proxy
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
|
||||
"""Synchronous version of safety checks."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
self._sync_wait_for_rate_limit()
|
||||
return True
|
||||
# Add rate limiting
|
||||
self.requests_per_second = requests_per_second
|
||||
self.last_request_time = None
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Load documents with rate limiting support, delegating to TavilyLoader."""
|
||||
valid_urls = []
|
||||
for url in self.web_paths:
|
||||
try:
|
||||
self._safe_process_url_sync(url)
|
||||
valid_urls.append(url)
|
||||
except Exception as e:
|
||||
log.warning(f"SSL verification failed for {url}: {str(e)}")
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
if not valid_urls:
|
||||
if self.continue_on_failure:
|
||||
log.warning("No valid URLs to process after SSL verification")
|
||||
return
|
||||
raise ValueError("No valid URLs to process after SSL verification")
|
||||
try:
|
||||
loader = TavilyLoader(
|
||||
urls=valid_urls,
|
||||
api_key=self.api_key,
|
||||
extract_depth=self.extract_depth,
|
||||
continue_on_failure=self.continue_on_failure,
|
||||
)
|
||||
yield from loader.lazy_load()
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error extracting content from URLs: {e}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Async version with rate limiting and SSL verification."""
|
||||
valid_urls = []
|
||||
for url in self.web_paths:
|
||||
try:
|
||||
await self._safe_process_url(url)
|
||||
valid_urls.append(url)
|
||||
except Exception as e:
|
||||
log.warning(f"SSL verification failed for {url}: {str(e)}")
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
if not valid_urls:
|
||||
if self.continue_on_failure:
|
||||
log.warning("No valid URLs to process after SSL verification")
|
||||
return
|
||||
raise ValueError("No valid URLs to process after SSL verification")
|
||||
|
||||
try:
|
||||
loader = TavilyLoader(
|
||||
urls=valid_urls,
|
||||
api_key=self.api_key,
|
||||
extract_depth=self.extract_depth,
|
||||
continue_on_failure=self.continue_on_failure,
|
||||
)
|
||||
async for document in loader.alazy_load():
|
||||
yield document
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error loading URLs: {e}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
|
||||
"""Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
|
||||
|
||||
Attributes:
|
||||
@@ -256,6 +377,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
headless (bool): If True, the browser will run in headless mode.
|
||||
proxy (dict): Proxy override settings for the Playwright session.
|
||||
playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
|
||||
playwright_timeout (Optional[int]): Maximum operation time in milliseconds.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -269,6 +391,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
remove_selectors: Optional[List[str]] = None,
|
||||
proxy: Optional[Dict[str, str]] = None,
|
||||
playwright_ws_url: Optional[str] = None,
|
||||
playwright_timeout: Optional[int] = 10000,
|
||||
):
|
||||
"""Initialize with additional safety parameters and remote browser support."""
|
||||
|
||||
@@ -295,6 +418,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
self.last_request_time = None
|
||||
self.playwright_ws_url = playwright_ws_url
|
||||
self.trust_env = trust_env
|
||||
self.playwright_timeout = playwright_timeout
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Safely load URLs synchronously with support for remote browser."""
|
||||
@@ -311,7 +435,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
try:
|
||||
self._safe_process_url_sync(url)
|
||||
page = browser.new_page()
|
||||
response = page.goto(url)
|
||||
response = page.goto(url, timeout=self.playwright_timeout)
|
||||
if response is None:
|
||||
raise ValueError(f"page.goto() returned None for url {url}")
|
||||
|
||||
@@ -320,7 +444,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
log.exception(f"Error loading {url}: {e}")
|
||||
continue
|
||||
raise e
|
||||
browser.close()
|
||||
@@ -342,7 +466,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
try:
|
||||
await self._safe_process_url(url)
|
||||
page = await browser.new_page()
|
||||
response = await page.goto(url)
|
||||
response = await page.goto(url, timeout=self.playwright_timeout)
|
||||
if response is None:
|
||||
raise ValueError(f"page.goto() returned None for url {url}")
|
||||
|
||||
@@ -351,46 +475,11 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
log.exception(f"Error loading {url}: {e}")
|
||||
continue
|
||||
raise e
|
||||
await browser.close()
|
||||
|
||||
def _verify_ssl_cert(self, url: str) -> bool:
|
||||
return verify_ssl_cert(url)
|
||||
|
||||
async def _wait_for_rate_limit(self):
|
||||
"""Wait to respect the rate limit if specified."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
await asyncio.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
def _sync_wait_for_rate_limit(self):
|
||||
"""Synchronous version of rate limit wait."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
time.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
async def _safe_process_url(self, url: str) -> bool:
|
||||
"""Perform safety checks before processing a URL."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
await self._wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
|
||||
"""Synchronous version of safety checks."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
self._sync_wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
|
||||
class SafeWebBaseLoader(WebBaseLoader):
|
||||
"""WebBaseLoader with enhanced error handling for URLs."""
|
||||
@@ -472,7 +561,7 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
# Log the error and continue with the next URL
|
||||
log.exception(e, "Error loading %s", path)
|
||||
log.exception(f"Error loading {path}: {e}")
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Async lazy load text from the url(s) in web_path."""
|
||||
@@ -499,6 +588,7 @@ RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
|
||||
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
|
||||
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
|
||||
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
|
||||
RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader
|
||||
|
||||
|
||||
def get_web_loader(
|
||||
@@ -518,13 +608,19 @@ def get_web_loader(
|
||||
"trust_env": trust_env,
|
||||
}
|
||||
|
||||
if PLAYWRIGHT_WS_URI.value:
|
||||
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
|
||||
if RAG_WEB_LOADER_ENGINE.value == "playwright":
|
||||
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000
|
||||
if PLAYWRIGHT_WS_URI.value:
|
||||
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
|
||||
|
||||
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
|
||||
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
|
||||
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
|
||||
|
||||
if RAG_WEB_LOADER_ENGINE.value == "tavily":
|
||||
web_loader_args["api_key"] = TAVILY_API_KEY.value
|
||||
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
|
||||
|
||||
# Create the appropriate WebLoader based on the configuration
|
||||
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
|
||||
web_loader = WebLoaderClass(**web_loader_args)
|
||||
|
||||
@@ -625,7 +625,9 @@ def transcription(
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
|
||||
supported_filetypes = ("audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a")
|
||||
|
||||
if not file.content_type.startswith(supported_filetypes):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
||||
|
||||
@@ -210,7 +210,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
LDAP_APP_DN,
|
||||
LDAP_APP_PASSWORD,
|
||||
auto_bind="NONE",
|
||||
authentication="SIMPLE",
|
||||
authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS",
|
||||
)
|
||||
if not connection_app.bind():
|
||||
raise HTTPException(400, detail="Application account bind failed")
|
||||
|
||||
@@ -2,6 +2,8 @@ import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from open_webui.socket.main import get_event_emitter
|
||||
from open_webui.models.chats import (
|
||||
ChatForm,
|
||||
ChatImportForm,
|
||||
@@ -372,6 +374,107 @@ async def update_chat_by_id(
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateChatMessageById
|
||||
############################
|
||||
class MessageForm(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
|
||||
async def update_chat_message_by_id(
|
||||
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
if chat.user_id != user.id and user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
chat = Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
id,
|
||||
message_id,
|
||||
{
|
||||
"content": form_data.content,
|
||||
},
|
||||
)
|
||||
|
||||
event_emitter = get_event_emitter(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"chat_id": id,
|
||||
"message_id": message_id,
|
||||
},
|
||||
False,
|
||||
)
|
||||
|
||||
if event_emitter:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:message",
|
||||
"data": {
|
||||
"chat_id": id,
|
||||
"message_id": message_id,
|
||||
"content": form_data.content,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
|
||||
|
||||
############################
|
||||
# SendChatMessageEventById
|
||||
############################
|
||||
class EventForm(BaseModel):
|
||||
type: str
|
||||
data: dict
|
||||
|
||||
|
||||
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
|
||||
async def send_chat_message_event_by_id(
|
||||
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
if chat.user_id != user.id and user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

event_emitter = get_event_emitter(
{
"user_id": user.id,
"chat_id": id,
"message_id": message_id,
}
)

try:
if event_emitter:
await event_emitter(form_data.model_dump())
else:
return False
return True
except:
return False

############################
# DeleteChatById
############################

@@ -81,7 +81,7 @@ def upload_file(
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
else:
elif file.content_type not in ["image/png", "image/jpeg", "image/gif"]:
process_file(request, ProcessFileForm(file_id=id), user=user)

file_item = Files.get_file_by_id(id=id)

@@ -517,10 +517,8 @@ async def image_generations(
images = []

for image in res["data"]:
if "url" in image:
image_data, content_type = load_url_image_data(
image["url"], headers
)
if image_url := image.get("url", None):
image_data, content_type = load_url_image_data(image_url, headers)
else:
image_data, content_type = load_b64_image_data(image["b64_json"])

@@ -437,14 +437,24 @@ def remove_file_from_knowledge_by_id(
)

# Remove content from the vector database
VECTOR_DB_CLIENT.delete(
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
)
try:
VECTOR_DB_CLIENT.delete(
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
)
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass

# Remove the file's collection from vector database
file_collection = f"file-{form_data.file_id}"
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
try:
# Remove the file's collection from vector database
file_collection = f"file-{form_data.file_id}"
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass

# Delete file from database
Files.delete_file_by_id(form_data.file_id)

@@ -295,7 +295,7 @@ async def update_config(
}

@cached(ttl=3)
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
@@ -336,6 +336,7 @@ async def get_all_models(request: Request, user: UserModel = None):
)

prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
model_ids = api_config.get("model_ids", [])

if len(model_ids) != 0 and "models" in response:
@@ -350,6 +351,10 @@ async def get_all_models(request: Request, user: UserModel = None):
for model in response.get("models", []):
model["model"] = f"{prefix_id}.{model['model']}"

if tags:
for model in response.get("models", []):
model["tags"] = tags

def merge_models_lists(model_lists):
merged_models = {}

@@ -1164,7 +1169,7 @@ async def generate_chat_completion(
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

# payload["keep_alive"] = -1 # keep alive forever
return await send_post_request(
url=f"{url}/api/chat",
payload=json.dumps(payload),
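As a hedged illustration of the prefix handling above (not part of the diff): a connection-level prefix_id is prepended to model names when listing, and stripped again before the request is forwarded to the upstream Ollama server. The connection name and model names below are made up.

# Minimal sketch of the prefix_id round-trip; "ollama-a" and the model names are hypothetical.
api_config = {"prefix_id": "ollama-a", "tags": [{"name": "internal"}]}

# Listing: prepend the prefix so models from different connections cannot collide.
models = [{"model": "llama3:8b"}, {"model": "qwen2:7b"}]
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
    for model in models:
        model["model"] = f"{prefix_id}.{model['model']}"

# Completion: strip the prefix again so the upstream server sees its own model name.
payload = {"model": "ollama-a.llama3:8b"}
if prefix_id:
    payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
assert payload["model"] == "llama3:8b"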
@@ -36,6 +36,9 @@ from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.misc import (
convert_logit_bias_input_to_json,
)

from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
@@ -350,6 +353,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
)

prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])

if prefix_id:
for model in (
@@ -357,6 +361,12 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
):
model["id"] = f"{prefix_id}.{model['id']}"

if tags:
for model in (
response if isinstance(response, list) else response.get("data", [])
):
model["tags"] = tags

log.debug(f"get_all_models:responses() {responses}")
return responses

@@ -374,7 +384,7 @@ async def get_filtered_models(models, user):
return filtered_models

@cached(ttl=3)
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")

@@ -396,6 +406,7 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:

for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:

merged_list.extend(
[
{
@@ -406,18 +417,21 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
if (model.get("id") or model.get("name"))
and (
"api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
)
]
)
@@ -666,6 +680,11 @@ async def generate_chat_completion(
del payload["max_tokens"]

# Convert the modified body back to JSON
if "logit_bias" in payload:
payload["logit_bias"] = json.loads(
convert_logit_bias_input_to_json(payload["logit_bias"])
)

payload = json.dumps(payload)

r = None

@@ -90,8 +90,8 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
@@ -139,8 +139,8 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
try:
res = (

@@ -358,6 +358,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -428,6 +429,7 @@ class DocumentIntelligenceConfigForm(BaseModel):
class ContentExtractionConfig(BaseModel):
engine: str = ""
tika_server_url: Optional[str] = None
docling_server_url: Optional[str] = None
document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None

@@ -540,6 +542,9 @@ async def update_rag_config(
request.app.state.config.TIKA_SERVER_URL = (
form_data.content_extraction.tika_server_url
)
request.app.state.config.DOCLING_SERVER_URL = (
form_data.content_extraction.docling_server_url
)
if form_data.content_extraction.document_intelligence_config is not None:
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.content_extraction.document_intelligence_config.endpoint
@@ -648,6 +653,7 @@ async def update_rag_config(
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -994,6 +1000,7 @@ def process_file(
loader = Loader(
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
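A hedged sketch of the shape the content_extraction block takes once the Docling URL is wired in as above; the engine value and all URLs here are placeholders for illustration, not defaults taken from this diff.

# Hypothetical update_rag_config payload fragment; every value below is a placeholder.
config_form = {
    "content_extraction": {
        "engine": "docling",                            # assumed engine name
        "tika_server_url": "http://tika:9998",          # placeholder Tika sidecar
        "docling_server_url": "http://docling:5001",    # placeholder, point at your Docling service
        "document_intelligence_config": {
            "endpoint": "",  # leave empty when Azure Document Intelligence is not used
            "key": "",
        },
    }
}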
@@ -2,6 +2,7 @@ import logging
from typing import Optional

from open_webui.models.auths import Auths
from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.users import (
UserModel,
@@ -17,7 +18,10 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel

from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
from open_webui.utils.access_control import get_permissions

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -45,7 +49,7 @@ async def get_users(

@router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
return Groups.get_groups_by_member_id(user.id)

############################
@@ -54,8 +58,12 @@ async def get_user_groups(user=Depends(get_verified_user)):

@router.get("/permissions")
async def get_user_permissisions(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)

return user_permissions

############################
@@ -89,7 +97,7 @@ class UserPermissions(BaseModel):

@router.get("/default/permissions", response_model=UserPermissions)
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)):
return {
"workspace": WorkspacePermissions(
**request.app.state.config.USER_PERMISSIONS.get("workspace", {})
@@ -104,7 +112,7 @@ async def get_user_permissions(request: Request, user=Depends(get_admin_user)):

@router.post("/default/permissions")
async def update_user_permissions(
async def update_default_user_permissions(
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
):
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
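For orientation only, a hedged sketch of how a client could exercise the reworked user permissions route, which now returns the caller's effective permissions instead of their groups. The base URL, route prefix, and token are placeholders; the response shape depends on USER_PERMISSIONS.

import requests  # any HTTP client works; requests is used only for illustration

BASE_URL = "http://localhost:8080"  # placeholder Open WebUI address
TOKEN = "sk-..."                    # placeholder API token of a verified user

resp = requests.get(
    f"{BASE_URL}/api/v1/users/permissions",  # assumed route prefix
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=10,
)
resp.raise_for_status()
print(resp.json())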
@@ -269,11 +269,19 @@ async def disconnect(sid):
# print(f"Unknown session ID {sid} disconnected")

def get_event_emitter(request_info):
def get_event_emitter(request_info, update_db=True):
async def __event_emitter__(event_data):
user_id = request_info["user_id"]

session_ids = list(
set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
set(
USER_POOL.get(user_id, [])
+ (
[request_info.get("session_id")]
if request_info.get("session_id")
else []
)
)
)

for session_id in session_ids:
@@ -287,40 +295,41 @@ def get_event_emitter(request_info):
to=session_id,
)

if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
if update_db:
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)

if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)

content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")

Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)

if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")

Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)

return __event_emitter__
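To make the new update_db switch concrete, here is a hedged, self-contained miniature of the same pattern (it does not use the Open WebUI classes above): the emitter always relays the event, and only persists it when update_db is true.

# Toy re-implementation of the update_db gate; print() stands in for the socket.io
# emission and the Chats table writes used in the real code.
def get_event_emitter(request_info, update_db=True):
    async def __event_emitter__(event_data):
        print("emit", request_info["chat_id"], event_data)  # relay to connected sessions
        if update_db and event_data.get("type") == "status":
            print("persist status", event_data.get("data", {}))  # write-through only when allowed
    return __event_emitter__

# Usage: relay an event without touching the database.
import asyncio
emitter = get_event_emitter(
    {"chat_id": "c1", "message_id": "m1", "user_id": "u1"}, update_db=False
)
asyncio.run(emitter({"type": "status", "data": {"description": "thinking"}}))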
@@ -106,6 +106,7 @@ async def process_filter_functions(

# Handle file cleanup for inlet
if skip_files and "files" in form_data.get("metadata", {}):
del form_data["files"]
del form_data["metadata"]["files"]

return form_data, {}

@@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])

async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}

event_caller = extra_params["__event_call__"]
metadata = extra_params["__metadata__"]

task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
@@ -189,19 +192,33 @@ async def chat_completion_tools_handler(
tool_function_params = tool_call.get("parameters", {})

try:
required_params = (
tools[tool_function_name]
.get("spec", {})
.get("parameters", {})
.get("required", [])
tool = tools[tool_function_name]

spec = tool.get("spec", {})
allowed_params = (
spec.get("parameters", {}).get("properties", {}).keys()
)
tool_function = tools[tool_function_name]["callable"]
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_output = await tool_function(**tool_function_params)

if tool.get("direct", False):
tool_output = await tool_function(**tool_function_params)
else:
tool_output = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get("session_id", None),
},
}
)

except Exception as e:
tool_output = str(e)
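A hedged, standalone sketch of the parameter filtering above: arguments proposed by the model are kept only if they appear in the tool spec's properties, rather than only in the required list, so optional parameters survive. The spec and argument values below are made up.

# Hypothetical OpenAPI-style tool spec; only the filtering logic mirrors the hunk above.
spec = {
    "parameters": {
        "properties": {"city": {"type": "string"}, "units": {"type": "string"}},
        "required": ["city"],
    }
}
proposed_params = {"city": "Berlin", "units": "metric", "made_up_arg": 42}

allowed_params = spec.get("parameters", {}).get("properties", {}).keys()
tool_function_params = {k: v for k, v in proposed_params.items() if k in allowed_params}

# Optional "units" is kept, unknown "made_up_arg" is dropped.
assert tool_function_params == {"city": "Berlin", "units": "metric"}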
@@ -767,12 +784,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata

# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_specs = form_data.get("tool_specs", None)

log.debug(f"{tool_ids=}")
log.debug(f"{tool_specs=}")

tools_dict = {}

if tool_ids:
# If tool_ids field is present, then get the tools
tools = get_tools(
tools_dict = get_tools(
request,
tool_ids,
user,
@@ -783,20 +806,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
log.info(f"{tools_dict=}")

if tool_specs:
for tool in tool_specs:
callable = tool.pop("callable", None)
tools_dict[tool["name"]] = {
"direct": True,
"callable": callable,
"spec": tool,
}

if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
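A hedged sketch of how a client-supplied tool spec ends up in tools_dict alongside server-side tools; the spec content is invented, and the callable slot would normally stay empty because direct tools are executed on the client via execute:tool events.

# Hypothetical client-side tool spec as it might arrive in form_data["tool_specs"].
tool_specs = [
    {
        "name": "get_time",
        "description": "Return the current time for a timezone",
        "parameters": {"properties": {"tz": {"type": "string"}}, "required": ["tz"]},
    }
]

tools_dict = {}
for tool in tool_specs:
    callable = tool.pop("callable", None)  # usually absent for client tools
    tools_dict[tool["name"]] = {
        "direct": True,         # executed on the client, not by the server
        "callable": callable,
        "spec": tool,
    }

# Native function calling then advertises the same specs to the model.
tools_payload = [{"type": "function", "function": t.get("spec", {})} for t in tools_dict.values()]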
@@ -815,7 +848,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
for source_idx, source in enumerate(sources):
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string += f"<source><source_id>{source_idx + 1}</source_id><source_context>{doc_context}</source_context></source>\n"

context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
@@ -1082,8 +1115,6 @@ async def process_chat_response(
for filter_id in get_sorted_filter_ids(model)
]

print(f"{filter_functions=}")

# Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
@@ -1563,7 +1594,9 @@ async def process_chat_response(

value = delta.get("content")

reasoning_content = delta.get("reasoning_content")
reasoning_content = delta.get(
"reasoning_content"
) or delta.get("reasoning")
if reasoning_content:
if (
not content_blocks
@@ -1766,18 +1799,36 @@ async def process_chat_response(
spec = tool.get("spec", {})

try:
required_params = spec.get("parameters", {}).get(
"required", []
allowed_params = (
spec.get("parameters", {})
.get("properties", {})
.keys()
)
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_result = await tool_function(
**tool_function_params
)

if tool.get("direct", False):
tool_result = await tool_function(
**tool_function_params
)
else:
tool_result = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get(
"session_id", None
),
},
}
)
except Exception as e:
tool_result = str(e)

@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
"tags": model.get("tags", []),
}
for model in ollama_models["models"]
]

@@ -94,7 +94,7 @@ class OAuthManager:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
oauth_roles = []
# Default/fallback role if no matching roles are found
role = auth_manager_config.DEFAULT_USER_ROLE

@@ -104,7 +104,7 @@ class OAuthManager:
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
oauth_roles = claim_data if isinstance(claim_data, list) else []

log.debug(f"Oauth Roles claim: {oauth_claim}")
log.debug(f"User roles from oauth: {oauth_roles}")
@@ -140,6 +140,7 @@ class OAuthManager:
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM

user_oauth_groups = []
# Nested claim search for groups claim
if oauth_claim:
claim_data = user_data
@@ -160,7 +161,7 @@ class OAuthManager:

# Remove groups that user is no longer a part of
for group_model in user_current_groups:
if group_model.name not in user_oauth_groups:
if user_oauth_groups and group_model.name not in user_oauth_groups:
# Remove group from user
log.debug(
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
@@ -186,8 +187,10 @@ class OAuthManager:

# Add user to new groups
for group_model in all_available_groups:
if group_model.name in user_oauth_groups and not any(
gm.name == group_model.name for gm in user_current_groups
if (
user_oauth_groups
and group_model.name in user_oauth_groups
and not any(gm.name == group_model.name for gm in user_current_groups)
):
# Add user to group
log.debug(

@@ -110,6 +110,11 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
"num_thread": int,
}

# Extract keep_alive from options if it exists
if "options" in form_data and "keep_alive" in form_data["options"]:
form_data["keep_alive"] = form_data["options"]["keep_alive"]
del form_data["options"]["keep_alive"]

return apply_model_params_to_body(params, form_data, mappings)

@@ -231,6 +236,11 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
"system"
] # To prevent Ollama warning of invalid option provided

# Extract keep_alive from options if it exists
if "keep_alive" in ollama_options:
ollama_payload["keep_alive"] = ollama_options["keep_alive"]
del ollama_options["keep_alive"]

# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
if "stop" in openai_payload:
ollama_options = ollama_payload.get("options", {})
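A hedged, standalone illustration of the keep_alive hoisting above: Ollama expects keep_alive at the top level of the request, so it is moved out of options before the payload is sent. The payload values are invented.

# Hypothetical payload with keep_alive nested under options.
form_data = {"model": "llama3:8b", "options": {"temperature": 0.7, "keep_alive": "5m"}}

if "options" in form_data and "keep_alive" in form_data["options"]:
    form_data["keep_alive"] = form_data["options"]["keep_alive"]  # promote to top level
    del form_data["options"]["keep_alive"]                        # avoid Ollama's invalid-option warning

assert form_data == {"model": "llama3:8b", "options": {"temperature": 0.7}, "keep_alive": "5m"}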
@@ -7,7 +7,7 @@ import types
import tempfile
import logging

from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, PIP_OPTIONS, PIP_PACKAGE_INDEX_OPTIONS
from open_webui.models.functions import Functions
from open_webui.models.tools import Tools

@@ -165,15 +165,19 @@ def load_function_module_by_id(function_id, content=None):
os.unlink(temp_file.name)

def install_frontmatter_requirements(requirements):
def install_frontmatter_requirements(requirements: str):
if requirements:
try:
req_list = [req.strip() for req in requirements.split(",")]
for req in req_list:
log.info(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
log.info(f"Installing requirements: {' '.join(req_list)}")
subprocess.check_call(
[sys.executable, "-m", "pip", "install"]
+ PIP_OPTIONS
+ req_list
+ PIP_PACKAGE_INDEX_OPTIONS
)
except Exception as e:
log.error(f"Error installing package: {req}")
log.error(f"Error installing packages: {' '.join(req_list)}")
raise e

else:
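For clarity, a hedged sketch of the resulting pip invocation; it assumes PIP_OPTIONS and PIP_PACKAGE_INDEX_OPTIONS are lists of extra command-line arguments parsed from the environment (for example --no-cache-dir or --index-url ...), which is not spelled out in this diff.

import subprocess
import sys

# Assumed to come from open_webui.env; shown here as plain lists for illustration.
PIP_OPTIONS = ["--no-cache-dir"]
PIP_PACKAGE_INDEX_OPTIONS = []  # e.g. ["--index-url", "https://pypi.org/simple"]

requirements = "requests, rich"
req_list = [req.strip() for req in requirements.split(",")]

# One pip call for all requirements instead of one subprocess per package.
cmd = [sys.executable, "-m", "pip", "install"] + PIP_OPTIONS + req_list + PIP_PACKAGE_INDEX_OPTIONS
print(" ".join(cmd))          # inspect the command
# subprocess.check_call(cmd)  # uncomment to actually install the packages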
backend/open_webui/utils/telemetry/__init__.py (new file, 0 lines)

backend/open_webui/utils/telemetry/constants.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from opentelemetry.semconv.trace import SpanAttributes as _SpanAttributes

# Span Tags
SPAN_DB_TYPE = "mysql"
SPAN_REDIS_TYPE = "redis"
SPAN_DURATION = "duration"
SPAN_SQL_STR = "sql"
SPAN_SQL_EXPLAIN = "explain"
SPAN_ERROR_TYPE = "error"

class SpanAttributes(_SpanAttributes):
"""
Span Attributes
"""

DB_INSTANCE = "db.instance"
DB_TYPE = "db.type"
DB_IP = "db.ip"
DB_PORT = "db.port"
ERROR_KIND = "error.kind"
ERROR_OBJECT = "error.object"
ERROR_MESSAGE = "error.message"
RESULT_CODE = "result.code"
RESULT_MESSAGE = "result.message"
RESULT_ERRORS = "result.errors"

backend/open_webui/utils/telemetry/exporters.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import threading

from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import BatchSpanProcessor

class LazyBatchSpanProcessor(BatchSpanProcessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.done = True
with self.condition:
self.condition.notify_all()
self.worker_thread.join()
self.done = False
self.worker_thread = None

def on_end(self, span: ReadableSpan) -> None:
if self.worker_thread is None:
self.worker_thread = threading.Thread(
name=self.__class__.__name__, target=self.worker, daemon=True
)
self.worker_thread.start()
super().on_end(span)

def shutdown(self) -> None:
self.done = True
with self.condition:
self.condition.notify_all()
if self.worker_thread:
self.worker_thread.join()
self.span_exporter.shutdown()
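A hedged usage sketch of the processor above: it behaves like BatchSpanProcessor but only starts its background worker once the first span ends. The console exporter and the import path are assumptions for illustration; the real wiring happens in setup.py below.

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter

# Assumes the processor is importable from the new module added in this diff.
from open_webui.utils.telemetry.exporters import LazyBatchSpanProcessor

provider = TracerProvider()
# No exporter thread is spawned yet; it starts lazily on the first span.
provider.add_span_processor(LazyBatchSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

tracer = trace.get_tracer("demo")
with tracer.start_as_current_span("first-span"):
    pass  # ending this span kicks off the worker thread and, later, the batch export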
backend/open_webui/utils/telemetry/instrumentors.py (new file, 202 lines)
@@ -0,0 +1,202 @@
import logging
import traceback
from typing import Collection, Union

from aiohttp import (
TraceRequestStartParams,
TraceRequestEndParams,
TraceRequestExceptionParams,
)
from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
from fastapi import FastAPI
from opentelemetry.instrumentation.httpx import (
HTTPXClientInstrumentor,
RequestInfo,
ResponseInfo,
)
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.trace import Span, StatusCode
from redis import Redis
from requests import PreparedRequest, Response
from sqlalchemy import Engine
from fastapi import status

from open_webui.utils.telemetry.constants import SPAN_REDIS_TYPE, SpanAttributes

from open_webui.env import SRC_LOG_LEVELS

logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])

def requests_hook(span: Span, request: PreparedRequest):
"""
Http Request Hook
"""

span.update_name(f"{request.method} {request.url}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: request.url,
SpanAttributes.HTTP_METHOD: request.method,
}
)

def response_hook(span: Span, request: PreparedRequest, response: Response):
"""
HTTP Response Hook
"""

span.set_attributes(
attributes={
SpanAttributes.HTTP_STATUS_CODE: response.status_code,
}
)
span.set_status(StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK)

def redis_request_hook(span: Span, instance: Redis, args, kwargs):
"""
Redis Request Hook
"""

try:
connection_kwargs: dict = instance.connection_pool.connection_kwargs
host = connection_kwargs.get("host")
port = connection_kwargs.get("port")
db = connection_kwargs.get("db")
span.set_attributes(
{
SpanAttributes.DB_INSTANCE: f"{host}/{db}",
SpanAttributes.DB_NAME: f"{host}/{db}",
SpanAttributes.DB_TYPE: SPAN_REDIS_TYPE,
SpanAttributes.DB_PORT: port,
SpanAttributes.DB_IP: host,
SpanAttributes.DB_STATEMENT: " ".join([str(i) for i in args]),
SpanAttributes.DB_OPERATION: str(args[0]),
}
)
except Exception:  # pylint: disable=W0718
logger.error(traceback.format_exc())

def httpx_request_hook(span: Span, request: RequestInfo):
"""
HTTPX Request Hook
"""

span.update_name(f"{request.method.decode()} {str(request.url)}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: str(request.url),
SpanAttributes.HTTP_METHOD: request.method.decode(),
}
)

def httpx_response_hook(span: Span, request: RequestInfo, response: ResponseInfo):
"""
HTTPX Response Hook
"""

span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.status_code)
span.set_status(
StatusCode.ERROR
if response.status_code >= status.HTTP_400_BAD_REQUEST
else StatusCode.OK
)

async def httpx_async_request_hook(span: Span, request: RequestInfo):
"""
Async Request Hook
"""

httpx_request_hook(span, request)

async def httpx_async_response_hook(
span: Span, request: RequestInfo, response: ResponseInfo
):
"""
Async Response Hook
"""

httpx_response_hook(span, request, response)

def aiohttp_request_hook(span: Span, request: TraceRequestStartParams):
"""
Aiohttp Request Hook
"""

span.update_name(f"{request.method} {str(request.url)}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: str(request.url),
SpanAttributes.HTTP_METHOD: request.method,
}
)

def aiohttp_response_hook(
span: Span, response: Union[TraceRequestExceptionParams, TraceRequestEndParams]
):
"""
Aiohttp Response Hook
"""

if isinstance(response, TraceRequestEndParams):
span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.response.status)
span.set_status(
StatusCode.ERROR
if response.response.status >= status.HTTP_400_BAD_REQUEST
else StatusCode.OK
)
elif isinstance(response, TraceRequestExceptionParams):
span.set_status(StatusCode.ERROR)
span.set_attribute(SpanAttributes.ERROR_MESSAGE, str(response.exception))

class Instrumentor(BaseInstrumentor):
"""
Instrument OT
"""

def __init__(self, app: FastAPI, db_engine: Engine):
self.app = app
self.db_engine = db_engine

def instrumentation_dependencies(self) -> Collection[str]:
return []

def _instrument(self, **kwargs):
instrument_fastapi(app=self.app)
SQLAlchemyInstrumentor().instrument(engine=self.db_engine)
RedisInstrumentor().instrument(request_hook=redis_request_hook)
RequestsInstrumentor().instrument(
request_hook=requests_hook, response_hook=response_hook
)
LoggingInstrumentor().instrument()
HTTPXClientInstrumentor().instrument(
request_hook=httpx_request_hook,
response_hook=httpx_response_hook,
async_request_hook=httpx_async_request_hook,
async_response_hook=httpx_async_response_hook,
)
AioHttpClientInstrumentor().instrument(
request_hook=aiohttp_request_hook,
response_hook=aiohttp_response_hook,
)

def _uninstrument(self, **kwargs):
if getattr(self, "instrumentors", None) is None:
return
for instrumentor in self.instrumentors:
instrumentor.uninstrument()
backend/open_webui/utils/telemetry/setup.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from sqlalchemy import Engine

from open_webui.utils.telemetry.exporters import LazyBatchSpanProcessor
from open_webui.utils.telemetry.instrumentors import Instrumentor
from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT

def setup(app: FastAPI, db_engine: Engine):
# set up trace
trace.set_tracer_provider(
TracerProvider(
resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
)
)
# otlp export
exporter = OTLPSpanExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT)
trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
Instrumentor(app=app, db_engine=db_engine).instrument()
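A hedged sketch of how an application might call the new setup() helper at startup; the SQLite URL is a placeholder, and the OTLP endpoint is read from OTEL_EXPORTER_OTLP_ENDPOINT, whose default is not shown in this diff.

from fastapi import FastAPI
from sqlalchemy import create_engine

from open_webui.utils.telemetry.setup import setup  # module added in this diff

app = FastAPI()
db_engine = create_engine("sqlite:///./webui.db")  # placeholder database URL

# Installs the tracer provider, the lazy OTLP batch processor, and all instrumentors.
setup(app=app, db_engine=db_engine)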
@@ -1,6 +1,9 @@
import inspect
import logging
import re
import inspect
import uuid

from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial

@@ -37,13 +37,13 @@ asgiref==3.8.1
# AI libraries
openai
anthropic
google-generativeai==0.7.2
google-generativeai==0.8.4
tiktoken

langchain==0.3.19
langchain-community==0.3.18

fake-useragent==1.5.1
fake-useragent==2.1.0
chromadb==0.6.2
pymilvus==2.5.0
qdrant-client~=1.12.0
@@ -78,6 +78,7 @@ sentencepiece
soundfile==0.13.1
azure-ai-documentintelligence==1.0.0

pillow==11.1.0
opencv-python-headless==4.11.0.86
rapidocr-onnxruntime==1.3.24
rank-bm25==0.2.2
@@ -118,3 +119,16 @@ ldap3==2.9.1

## Firecrawl
firecrawl-py==1.12.0

## Trace
opentelemetry-api==1.30.0
opentelemetry-sdk==1.30.0
opentelemetry-exporter-otlp==1.30.0
opentelemetry-instrumentation==0.51b0
opentelemetry-instrumentation-fastapi==0.51b0
opentelemetry-instrumentation-sqlalchemy==0.51b0
opentelemetry-instrumentation-redis==0.51b0
opentelemetry-instrumentation-requests==0.51b0
opentelemetry-instrumentation-logging==0.51b0
opentelemetry-instrumentation-httpx==0.51b0
opentelemetry-instrumentation-aiohttp-client==0.51b0
@@ -41,4 +41,5 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (

:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn open_webui.main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'
uvicorn open_webui.main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*' --ws auto
:: For ssl user uvicorn open_webui.main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*' --ssl-keyfile "key.pem" --ssl-certfile "cert.pem" --ws auto
Block a user