chore: format

This commit is contained in:
Timothy Jaeryang Baek
2025-02-20 01:01:29 -08:00
parent 2b913a99a3
commit eeb00a5ca2
60 changed files with 1548 additions and 1210 deletions

View File

@@ -15,15 +15,12 @@ from typing import (
Optional,
Sequence,
Union,
Literal
Literal,
)
import aiohttp
import certifi
import validators
from langchain_community.document_loaders import (
PlaywrightURLLoader,
WebBaseLoader
)
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
@@ -33,7 +30,7 @@ from open_webui.config import (
PLAYWRIGHT_WS_URI,
RAG_WEB_LOADER_ENGINE,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY
FIRECRAWL_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -75,6 +72,7 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
continue
return valid_urls
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
@@ -85,16 +83,13 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def extract_metadata(soup, url):
metadata = {
"source": url
}
metadata = {"source": url}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
metadata["description"] = description.get("content", "No description found.")
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
return metadata
@@ -104,7 +99,7 @@ def verify_ssl_cert(url: str) -> bool:
"""Verify SSL certificate for the given URL."""
if not url.startswith("https://"):
return True
try:
hostname = url.split("://")[-1].split("/")[0]
context = ssl.create_default_context(cafile=certifi.where())
@@ -133,7 +128,7 @@ class SafeFireCrawlLoader(BaseLoader):
params: Optional[Dict] = None,
):
"""Concurrent document loader for FireCrawl operations.
Executes multiple FireCrawlLoader instances concurrently using thread pooling
to improve bulk processing efficiency.
Args:
@@ -142,7 +137,7 @@ class SafeFireCrawlLoader(BaseLoader):
trust_env: If True, use proxy settings from environment variables.
requests_per_second: Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
api_key: API key for FireCrawl service. Defaults to None
api_key: API key for FireCrawl service. Defaults to None
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
mode: Operation mode selection:
@@ -154,15 +149,15 @@ class SafeFireCrawlLoader(BaseLoader):
Examples include crawlerOptions.
For more details, visit: https://github.com/mendableai/firecrawl-py
"""
proxy_server = proxy.get('server') if proxy else None
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy['server'] = env_proxy_server
proxy["server"] = env_proxy_server
else:
proxy = { 'server': env_proxy_server }
proxy = {"server": env_proxy_server}
self.web_paths = web_paths
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
@@ -184,7 +179,7 @@ class SafeFireCrawlLoader(BaseLoader):
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params
params=self.params,
)
yield from loader.lazy_load()
except Exception as e:
@@ -203,7 +198,7 @@ class SafeFireCrawlLoader(BaseLoader):
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params
params=self.params,
)
async for document in loader.alazy_load():
yield document
@@ -251,7 +246,7 @@ class SafeFireCrawlLoader(BaseLoader):
class SafePlaywrightURLLoader(PlaywrightURLLoader):
"""Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
Attributes:
web_paths (List[str]): List of URLs to load.
verify_ssl (bool): If True, verify SSL certificates.
@@ -273,19 +268,19 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
headless: bool = True,
remove_selectors: Optional[List[str]] = None,
proxy: Optional[Dict[str, str]] = None,
playwright_ws_url: Optional[str] = None
playwright_ws_url: Optional[str] = None,
):
"""Initialize with additional safety parameters and remote browser support."""
proxy_server = proxy.get('server') if proxy else None
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy['server'] = env_proxy_server
proxy["server"] = env_proxy_server
else:
proxy = { 'server': env_proxy_server }
proxy = {"server": env_proxy_server}
# We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
super().__init__(
@@ -293,7 +288,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
continue_on_failure=continue_on_failure,
headless=headless if playwright_ws_url is None else False,
remove_selectors=remove_selectors,
proxy=proxy
proxy=proxy,
)
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
@@ -339,7 +334,9 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
if self.playwright_ws_url:
browser = await p.chromium.connect(self.playwright_ws_url)
else:
browser = await p.chromium.launch(headless=self.headless, proxy=self.proxy)
browser = await p.chromium.launch(
headless=self.headless, proxy=self.proxy
)
for url in self.urls:
try:
@@ -394,6 +391,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
self._sync_wait_for_rate_limit()
return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
@@ -496,11 +494,13 @@ class SafeWebBaseLoader(WebBaseLoader):
"""Load data into Document objects."""
return [document async for document in self.alazy_load()]
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
def get_web_loader(
urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
@@ -515,7 +515,7 @@ def get_web_loader(
"verify_ssl": verify_ssl,
"requests_per_second": requests_per_second,
"continue_on_failure": True,
"trust_env": trust_env
"trust_env": trust_env,
}
if PLAYWRIGHT_WS_URI.value:
@@ -529,6 +529,10 @@ def get_web_loader(
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
web_loader = WebLoaderClass(**web_loader_args)
log.debug("Using RAG_WEB_LOADER_ENGINE %s for %s URLs", web_loader.__class__.__name__, len(safe_urls))
log.debug(
"Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
web_loader.__class__.__name__,
len(safe_urls),
)
return web_loader
return web_loader

View File

@@ -267,8 +267,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
try:
# print(payload)
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload,
@@ -325,8 +327,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
)
try:
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
json={
@@ -383,8 +387,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
headers={

View File

@@ -547,7 +547,7 @@ async def signout(request: Request, response: Response):
response.delete_cookie("oauth_id_token")
return RedirectResponse(
headers=response.headers,
url=f"{logout_url}?id_token_hint={oauth_id_token}"
url=f"{logout_url}?id_token_hint={oauth_id_token}",
)
else:
raise HTTPException(

View File

@@ -944,8 +944,12 @@ class ChatMessage(BaseModel):
@classmethod
def check_at_least_one_field(cls, field_value, values, **kwargs):
# Raise an error if both 'content' and 'tool_calls' are None
if field_value is None and ("tool_calls" not in values or values["tool_calls"] is None):
raise ValueError("At least one of 'content' or 'tool_calls' must be provided")
if field_value is None and (
"tool_calls" not in values or values["tool_calls"] is None
):
raise ValueError(
"At least one of 'content' or 'tool_calls' must be provided"
)
return field_value

View File

@@ -253,23 +253,32 @@ class OAuthManager:
if provider == "github":
try:
access_token = token.get("access_token")
headers = {
"Authorization": f"Bearer {access_token}"
}
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.get("https://api.github.com/user/emails", headers=headers) as resp:
async with session.get(
"https://api.github.com/user/emails", headers=headers
) as resp:
if resp.ok:
emails = await resp.json()
# use the primary email as the user's email
primary_email = next((e["email"] for e in emails if e.get("primary")), None)
primary_email = next(
(e["email"] for e in emails if e.get("primary")),
None,
)
if primary_email:
email = primary_email
else:
log.warning("No primary email found in GitHub response")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
log.warning(
"No primary email found in GitHub response"
)
raise HTTPException(
400, detail=ERROR_MESSAGES.INVALID_CRED
)
else:
log.warning("Failed to fetch GitHub email")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
raise HTTPException(
400, detail=ERROR_MESSAGES.INVALID_CRED
)
except Exception as e:
log.warning(f"Error fetching GitHub email: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

View File

@@ -151,7 +151,7 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
# Put the content to empty string (Ollama requires an empty string for tool calls)
new_message["content"] = ""
else:
# Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
content_text = ""
@@ -215,16 +215,20 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
if openai_payload.get("options"):
ollama_payload["options"] = openai_payload["options"]
ollama_options = openai_payload["options"]
# Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
if "max_tokens" in ollama_options:
ollama_options["num_predict"] = ollama_options["max_tokens"]
del ollama_options["max_tokens"] # To prevent Ollama warning of invalid option provided
ollama_options["num_predict"] = ollama_options["max_tokens"]
del ollama_options[
"max_tokens"
] # To prevent Ollama warning of invalid option provided
# Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
if "system" in ollama_options:
ollama_payload["system"] = ollama_options["system"]
del ollama_options["system"] # To prevent Ollama warning of invalid option provided
ollama_payload["system"] = ollama_options["system"]
del ollama_options[
"system"
] # To prevent Ollama warning of invalid option provided
if "metadata" in openai_payload:
ollama_payload["metadata"] = openai_payload["metadata"]

View File

@@ -23,6 +23,7 @@ def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
openai_tool_calls.append(openai_tool_call)
return openai_tool_calls
def convert_ollama_usage_to_openai(data: dict) -> dict:
return {
"response_token/s": (
@@ -56,24 +57,29 @@ def convert_ollama_usage_to_openai(data: dict) -> dict:
"total_duration": data.get("total_duration", 0),
"load_duration": data.get("load_duration", 0),
"prompt_eval_count": data.get("prompt_eval_count", 0),
"prompt_tokens": int(data.get("prompt_eval_count", 0)), # This is the OpenAI compatible key
"prompt_tokens": int(
data.get("prompt_eval_count", 0)
), # This is the OpenAI compatible key
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
"eval_count": data.get("eval_count", 0),
"completion_tokens": int(data.get("eval_count", 0)), # This is the OpenAI compatible key
"completion_tokens": int(
data.get("eval_count", 0)
), # This is the OpenAI compatible key
"eval_duration": data.get("eval_duration", 0),
"approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")(
(data.get("total_duration", 0) or 0) // 1_000_000_000
),
"total_tokens": int( # This is the OpenAI compatible key
"total_tokens": int( # This is the OpenAI compatible key
data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
),
"completion_tokens_details": { # This is the OpenAI compatible key
"completion_tokens_details": { # This is the OpenAI compatible key
"reasoning_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
"rejected_prediction_tokens": 0,
},
}
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
model = ollama_response.get("model", "ollama")
message_content = ollama_response.get("message", {}).get("content", "")