mirror of https://github.com/open-webui/open-webui.git
chore: format
@@ -15,15 +15,12 @@ from typing import (
     Optional,
     Sequence,
     Union,
-    Literal
+    Literal,
 )
 import aiohttp
 import certifi
 import validators
-from langchain_community.document_loaders import (
-    PlaywrightURLLoader,
-    WebBaseLoader
-)
+from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
 from langchain_community.document_loaders.firecrawl import FireCrawlLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
@@ -33,7 +30,7 @@ from open_webui.config import (
     PLAYWRIGHT_WS_URI,
     RAG_WEB_LOADER_ENGINE,
     FIRECRAWL_API_BASE_URL,
-    FIRECRAWL_API_KEY
+    FIRECRAWL_API_KEY,
 )
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -75,6 +72,7 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
             continue
     return valid_urls
 
+
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
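For orientation, resolve_hostname (shown as context above) splits socket.getaddrinfo results into IPv4 and IPv6 buckets. A minimal standalone sketch of that pattern; the helper name and the list-comprehension filtering are inferred from the surrounding context, not copied from the full file:

import socket


def resolve_hostname_sketch(hostname: str):
    # getaddrinfo returns (family, type, proto, canonname, sockaddr) tuples
    addr_info = socket.getaddrinfo(hostname, None)

    # sockaddr[0] is the address string for both IPv4 and IPv6 entries
    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
    return ipv4_addresses, ipv6_addresses
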
@@ -85,16 +83,13 @@ def resolve_hostname(hostname):
 
     return ipv4_addresses, ipv6_addresses
 
+
 def extract_metadata(soup, url):
-    metadata = {
-        "source": url
-    }
+    metadata = {"source": url}
     if title := soup.find("title"):
         metadata["title"] = title.get_text()
     if description := soup.find("meta", attrs={"name": "description"}):
-        metadata["description"] = description.get(
-            "content", "No description found."
-        )
+        metadata["description"] = description.get("content", "No description found.")
     if html := soup.find("html"):
         metadata["language"] = html.get("lang", "No language found.")
     return metadata
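The collapsed extract_metadata lines above keep the same walrus-operator checks; a small usage sketch with an invented HTML snippet shows what the function produces:

from bs4 import BeautifulSoup

html = (
    '<html lang="en"><head><title>Open WebUI</title>'
    '<meta name="description" content="Self-hosted AI interface"></head></html>'
)
soup = BeautifulSoup(html, "html.parser")

metadata = {"source": "https://example.com"}
if title := soup.find("title"):
    metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
    metadata["description"] = description.get("content", "No description found.")
if html_tag := soup.find("html"):
    metadata["language"] = html_tag.get("lang", "No language found.")
# metadata == {"source": "https://example.com", "title": "Open WebUI",
#              "description": "Self-hosted AI interface", "language": "en"}
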
@@ -104,7 +99,7 @@ def verify_ssl_cert(url: str) -> bool:
     """Verify SSL certificate for the given URL."""
     if not url.startswith("https://"):
         return True
-
+
     try:
         hostname = url.split("://")[-1].split("/")[0]
         context = ssl.create_default_context(cafile=certifi.where())
@@ -133,7 +128,7 @@ class SafeFireCrawlLoader(BaseLoader):
         params: Optional[Dict] = None,
     ):
         """Concurrent document loader for FireCrawl operations.
-
+
         Executes multiple FireCrawlLoader instances concurrently using thread pooling
         to improve bulk processing efficiency.
         Args:
@@ -142,7 +137,7 @@ class SafeFireCrawlLoader(BaseLoader):
             trust_env: If True, use proxy settings from environment variables.
             requests_per_second: Number of requests per second to limit to.
             continue_on_failure (bool): If True, continue loading other URLs on failure.
-            api_key: API key for FireCrawl service. Defaults to None
+            api_key: API key for FireCrawl service. Defaults to None
                 (uses FIRE_CRAWL_API_KEY environment variable if not provided).
             api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
             mode: Operation mode selection:
@@ -154,15 +149,15 @@ class SafeFireCrawlLoader(BaseLoader):
                 Examples include crawlerOptions.
                 For more details, visit: https://github.com/mendableai/firecrawl-py
         """
-        proxy_server = proxy.get('server') if proxy else None
+        proxy_server = proxy.get("server") if proxy else None
         if trust_env and not proxy_server:
             env_proxies = urllib.request.getproxies()
-            env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
             if env_proxy_server:
                 if proxy:
-                    proxy['server'] = env_proxy_server
+                    proxy["server"] = env_proxy_server
                 else:
-                    proxy = { 'server': env_proxy_server }
+                    proxy = {"server": env_proxy_server}
         self.web_paths = web_paths
         self.verify_ssl = verify_ssl
         self.requests_per_second = requests_per_second
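The quote normalization above sits inside the environment-proxy fallback shared by both loaders. Pulled out as a standalone sketch (the function name and return value are illustrative, not part of the file):

import urllib.request
from typing import Dict, Optional


def apply_env_proxy(proxy: Optional[Dict], trust_env: bool) -> Optional[Dict]:
    # Only fall back to environment proxies when no explicit server was configured
    proxy_server = proxy.get("server") if proxy else None
    if trust_env and not proxy_server:
        env_proxies = urllib.request.getproxies()  # honours http_proxy / https_proxy
        env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
        if env_proxy_server:
            if proxy:
                proxy["server"] = env_proxy_server
            else:
                proxy = {"server": env_proxy_server}
    return proxy
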
@@ -184,7 +179,7 @@ class SafeFireCrawlLoader(BaseLoader):
                     api_key=self.api_key,
                     api_url=self.api_url,
                     mode=self.mode,
-                    params=self.params
+                    params=self.params,
                 )
                 yield from loader.lazy_load()
             except Exception as e:
@@ -203,7 +198,7 @@ class SafeFireCrawlLoader(BaseLoader):
                     api_key=self.api_key,
                     api_url=self.api_url,
                     mode=self.mode,
-                    params=self.params
+                    params=self.params,
                 )
                 async for document in loader.alazy_load():
                     yield document
@@ -251,7 +246,7 @@ class SafeFireCrawlLoader(BaseLoader):
 
 class SafePlaywrightURLLoader(PlaywrightURLLoader):
     """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
-
+
     Attributes:
         web_paths (List[str]): List of URLs to load.
         verify_ssl (bool): If True, verify SSL certificates.
@@ -273,19 +268,19 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
         headless: bool = True,
         remove_selectors: Optional[List[str]] = None,
         proxy: Optional[Dict[str, str]] = None,
-        playwright_ws_url: Optional[str] = None
+        playwright_ws_url: Optional[str] = None,
     ):
         """Initialize with additional safety parameters and remote browser support."""
 
-        proxy_server = proxy.get('server') if proxy else None
+        proxy_server = proxy.get("server") if proxy else None
         if trust_env and not proxy_server:
             env_proxies = urllib.request.getproxies()
-            env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
             if env_proxy_server:
                 if proxy:
-                    proxy['server'] = env_proxy_server
+                    proxy["server"] = env_proxy_server
                 else:
-                    proxy = { 'server': env_proxy_server }
+                    proxy = {"server": env_proxy_server}
 
         # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
         super().__init__(
@@ -293,7 +288,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
             continue_on_failure=continue_on_failure,
             headless=headless if playwright_ws_url is None else False,
             remove_selectors=remove_selectors,
-            proxy=proxy
+            proxy=proxy,
         )
         self.verify_ssl = verify_ssl
         self.requests_per_second = requests_per_second
@@ -339,7 +334,9 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
             if self.playwright_ws_url:
                 browser = await p.chromium.connect(self.playwright_ws_url)
             else:
-                browser = await p.chromium.launch(headless=self.headless, proxy=self.proxy)
+                browser = await p.chromium.launch(
+                    headless=self.headless, proxy=self.proxy
+                )
 
             for url in self.urls:
                 try:
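The launch call above is only re-wrapped; the connect-versus-launch decision it belongs to can be sketched as a standalone coroutine. This is a hedged illustration, not the loader's implementation, and the parameters are placeholders:

from playwright.async_api import async_playwright


async def fetch_title(url: str, playwright_ws_url: str | None = None, headless: bool = True):
    async with async_playwright() as p:
        if playwright_ws_url:
            # Attach to an already running remote browser over its websocket endpoint
            browser = await p.chromium.connect(playwright_ws_url)
        else:
            # Otherwise start a local Chromium instance
            browser = await p.chromium.launch(headless=headless)
        page = await browser.new_page()
        await page.goto(url)
        title = await page.title()
        await browser.close()
        return title
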
@@ -394,6 +391,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
         self._sync_wait_for_rate_limit()
         return True
 
+
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
@@ -496,11 +494,13 @@ class SafeWebBaseLoader(WebBaseLoader):
         """Load data into Document objects."""
         return [document async for document in self.alazy_load()]
 
+
 RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
 RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
 RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
 RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
 
+
 def get_web_loader(
     urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
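The registry touched by this hunk is a defaultdict, so any unrecognized engine name silently falls back to SafeWebBaseLoader. A tiny sketch of that lookup behaviour with stand-in classes (the stub names are invented):

from collections import defaultdict


class SafeWebBaseLoaderStub: ...
class SafePlaywrightURLLoaderStub: ...


engines = defaultdict(lambda: SafeWebBaseLoaderStub)
engines["playwright"] = SafePlaywrightURLLoaderStub

assert engines["playwright"] is SafePlaywrightURLLoaderStub
assert engines["unknown_engine"] is SafeWebBaseLoaderStub  # default factory kicks in
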
@@ -515,7 +515,7 @@ def get_web_loader(
         "verify_ssl": verify_ssl,
         "requests_per_second": requests_per_second,
         "continue_on_failure": True,
-        "trust_env": trust_env
+        "trust_env": trust_env,
     }
 
     if PLAYWRIGHT_WS_URI.value:
@@ -529,6 +529,10 @@ def get_web_loader(
     WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
     web_loader = WebLoaderClass(**web_loader_args)
 
-    log.debug("Using RAG_WEB_LOADER_ENGINE %s for %s URLs", web_loader.__class__.__name__, len(safe_urls))
+    log.debug(
+        "Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
+        web_loader.__class__.__name__,
+        len(safe_urls),
+    )
 
-    return web_loader
+    return web_loader
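The log.debug call is split across lines but keeps %s placeholders, so interpolation stays deferred until the record is actually emitted. The same idiom in isolation (logger name and values are illustrative):

import logging

log = logging.getLogger("rag.web")
engine_name, url_count = "SafePlaywrightURLLoader", 3  # illustrative values

# Arguments are only formatted if DEBUG is enabled for this logger
log.debug("Using RAG_WEB_LOADER_ENGINE %s for %s URLs", engine_name, url_count)
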
@@ -267,8 +267,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 
         try:
             # print(payload)
-            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
-            async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+            async with aiohttp.ClientSession(
+                timeout=timeout, trust_env=True
+            ) as session:
                 async with session.post(
                     url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                     json=payload,
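The audio hunks above and below only re-wrap the session setup; the underlying aiohttp pattern is a total-request timeout plus trust_env so proxy environment variables are honoured. A hedged sketch with a placeholder endpoint and payload:

import aiohttp


async def post_speech(payload: dict) -> bytes:
    timeout = aiohttp.ClientTimeout(total=300)  # stand-in for AIOHTTP_CLIENT_TIMEOUT
    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.post(
            "https://api.example.com/v1/audio/speech",  # placeholder endpoint
            json=payload,
            headers={"Authorization": "Bearer <api-key>"},
        ) as r:
            r.raise_for_status()
            return await r.read()

# e.g. asyncio.run(post_speech({"input": "hello", "voice": "alloy"}))
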
@@ -325,8 +327,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             )
 
         try:
-            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
-            async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+            async with aiohttp.ClientSession(
+                timeout=timeout, trust_env=True
+            ) as session:
                 async with session.post(
                     f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                     json={
@@ -383,8 +387,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                 <voice name="{language}">{payload["input"]}</voice>
             </speak>"""
-            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
-            async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+            async with aiohttp.ClientSession(
+                timeout=timeout, trust_env=True
+            ) as session:
                 async with session.post(
                     f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
                     headers={
@@ -547,7 +547,7 @@ async def signout(request: Request, response: Response):
                 response.delete_cookie("oauth_id_token")
                 return RedirectResponse(
                     headers=response.headers,
-                    url=f"{logout_url}?id_token_hint={oauth_id_token}"
+                    url=f"{logout_url}?id_token_hint={oauth_id_token}",
                 )
             else:
                 raise HTTPException(
@@ -944,8 +944,12 @@ class ChatMessage(BaseModel):
     @classmethod
     def check_at_least_one_field(cls, field_value, values, **kwargs):
         # Raise an error if both 'content' and 'tool_calls' are None
-        if field_value is None and ("tool_calls" not in values or values["tool_calls"] is None):
-            raise ValueError("At least one of 'content' or 'tool_calls' must be provided")
+        if field_value is None and (
+            "tool_calls" not in values or values["tool_calls"] is None
+        ):
+            raise ValueError(
+                "At least one of 'content' or 'tool_calls' must be provided"
+            )
 
         return field_value
 
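The validator body above is only re-wrapped; the constraint itself is simply that content and tool_calls cannot both be missing. Stripped of the model wiring (standalone sketch, not the project's pydantic setup):

def check_at_least_one_field(content, values: dict):
    # Reject messages where neither plain content nor a tool call is present
    if content is None and ("tool_calls" not in values or values["tool_calls"] is None):
        raise ValueError("At least one of 'content' or 'tool_calls' must be provided")
    return content


check_at_least_one_field("hello", {})                               # ok: content present
check_at_least_one_field(None, {"tool_calls": [{"id": "call_0"}]})  # ok: tool call present
# check_at_least_one_field(None, {})  # would raise ValueError
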
@@ -253,23 +253,32 @@ class OAuthManager:
         if provider == "github":
             try:
                 access_token = token.get("access_token")
-                headers = {
-                    "Authorization": f"Bearer {access_token}"
-                }
+                headers = {"Authorization": f"Bearer {access_token}"}
                 async with aiohttp.ClientSession() as session:
-                    async with session.get("https://api.github.com/user/emails", headers=headers) as resp:
+                    async with session.get(
+                        "https://api.github.com/user/emails", headers=headers
+                    ) as resp:
                         if resp.ok:
                             emails = await resp.json()
                             # use the primary email as the user's email
-                            primary_email = next((e["email"] for e in emails if e.get("primary")), None)
+                            primary_email = next(
+                                (e["email"] for e in emails if e.get("primary")),
+                                None,
+                            )
                             if primary_email:
                                 email = primary_email
                             else:
-                                log.warning("No primary email found in GitHub response")
-                                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+                                log.warning(
+                                    "No primary email found in GitHub response"
+                                )
+                                raise HTTPException(
+                                    400, detail=ERROR_MESSAGES.INVALID_CRED
+                                )
                         else:
                             log.warning("Failed to fetch GitHub email")
-                            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+                            raise HTTPException(
+                                400, detail=ERROR_MESSAGES.INVALID_CRED
+                            )
             except Exception as e:
                 log.warning(f"Error fetching GitHub email: {e}")
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
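The re-wrapped primary-email selection is a next() over the GitHub /user/emails response; in isolation, with invented sample data:

emails = [  # shaped like the GitHub /user/emails response
    {"email": "work@example.com", "primary": False, "verified": True},
    {"email": "me@example.com", "primary": True, "verified": True},
]

primary_email = next(
    (e["email"] for e in emails if e.get("primary")),
    None,  # fall back to None so the caller can raise a clean error
)
assert primary_email == "me@example.com"
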
@@ -151,7 +151,7 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
 
             # Put the content to empty string (Ollama requires an empty string for tool calls)
             new_message["content"] = ""
-
+
         else:
             # Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
             content_text = ""
@@ -215,16 +215,20 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
     if openai_payload.get("options"):
         ollama_payload["options"] = openai_payload["options"]
         ollama_options = openai_payload["options"]
-
+
         # Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
         if "max_tokens" in ollama_options:
-            ollama_options["num_predict"] = ollama_options["max_tokens"]
-            del ollama_options["max_tokens"] # To prevent Ollama warning of invalid option provided
+            ollama_options["num_predict"] = ollama_options["max_tokens"]
+            del ollama_options[
+                "max_tokens"
+            ]  # To prevent Ollama warning of invalid option provided
 
         # Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
         if "system" in ollama_options:
-            ollama_payload["system"] = ollama_options["system"]
-            del ollama_options["system"] # To prevent Ollama warning of invalid option provided
+            ollama_payload["system"] = ollama_options["system"]
+            del ollama_options[
+                "system"
+            ]  # To prevent Ollama warning of invalid option provided
 
     if "metadata" in openai_payload:
         ollama_payload["metadata"] = openai_payload["metadata"]
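The options handling that black re-wraps here moves OpenAI-style fields into their Ollama equivalents; a condensed sketch of the effect on an invented payload (it uses pop where the file uses assignment plus del, for brevity):

ollama_payload = {"model": "llama3"}
ollama_options = {"max_tokens": 256, "system": "You are terse.", "temperature": 0.2}

# OpenAI's max_tokens maps to Ollama's num_predict
if "max_tokens" in ollama_options:
    ollama_options["num_predict"] = ollama_options.pop("max_tokens")

# Ollama takes the system prompt as a top-level parameter, not an option
if "system" in ollama_options:
    ollama_payload["system"] = ollama_options.pop("system")

ollama_payload["options"] = ollama_options
# -> {"model": "llama3", "system": "You are terse.",
#     "options": {"temperature": 0.2, "num_predict": 256}}
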
@@ -23,6 +23,7 @@ def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
         openai_tool_calls.append(openai_tool_call)
     return openai_tool_calls
 
+
 def convert_ollama_usage_to_openai(data: dict) -> dict:
     return {
         "response_token/s": (
@@ -56,24 +57,29 @@ def convert_ollama_usage_to_openai(data: dict) -> dict:
         "total_duration": data.get("total_duration", 0),
         "load_duration": data.get("load_duration", 0),
         "prompt_eval_count": data.get("prompt_eval_count", 0),
-        "prompt_tokens": int(data.get("prompt_eval_count", 0)), # This is the OpenAI compatible key
+        "prompt_tokens": int(
+            data.get("prompt_eval_count", 0)
+        ),  # This is the OpenAI compatible key
         "prompt_eval_duration": data.get("prompt_eval_duration", 0),
         "eval_count": data.get("eval_count", 0),
-        "completion_tokens": int(data.get("eval_count", 0)), # This is the OpenAI compatible key
+        "completion_tokens": int(
+            data.get("eval_count", 0)
+        ),  # This is the OpenAI compatible key
         "eval_duration": data.get("eval_duration", 0),
         "approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")(
             (data.get("total_duration", 0) or 0) // 1_000_000_000
         ),
-        "total_tokens": int( # This is the OpenAI compatible key
+        "total_tokens": int(  # This is the OpenAI compatible key
             data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
         ),
-        "completion_tokens_details": { # This is the OpenAI compatible key
+        "completion_tokens_details": {  # This is the OpenAI compatible key
             "reasoning_tokens": 0,
             "accepted_prediction_tokens": 0,
-            "rejected_prediction_tokens": 0
-        }
+            "rejected_prediction_tokens": 0,
+        },
     }
 
+
 def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     model = ollama_response.get("model", "ollama")
     message_content = ollama_response.get("message", {}).get("content", "")
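The approximate_total entry shown as context converts Ollama's nanosecond total_duration into an hours/minutes/seconds string; the arithmetic in isolation, with an invented duration:

total_duration_ns = 93_500_000_000  # ~93.5 s; Ollama reports durations in nanoseconds

seconds = (total_duration_ns or 0) // 1_000_000_000
approximate_total = f"{seconds // 3600}h{(seconds % 3600) // 60}m{seconds % 60}s"
assert approximate_total == "0h1m33s"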