Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-16 03:47:49 +01:00)
Merge remote-tracking branch 'upstream/dev' into playwright
# Conflicts:
#	backend/open_webui/config.py
#	backend/open_webui/main.py
#	backend/open_webui/retrieval/web/utils.py
#	backend/open_webui/routers/retrieval.py
#	backend/open_webui/utils/middleware.py
#	pyproject.toml
backend/open_webui/config.py
@@ -1190,6 +1190,12 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
     os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
 )
 
+ENABLE_TITLE_GENERATION = PersistentConfig(
+    "ENABLE_TITLE_GENERATION",
+    "task.title.enable",
+    os.environ.get("ENABLE_TITLE_GENERATION", "True").lower() == "true",
+)
+
 
 ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
     "ENABLE_SEARCH_QUERY_GENERATION",
@@ -1803,6 +1809,18 @@ SEARCHAPI_ENGINE = PersistentConfig(
     os.getenv("SEARCHAPI_ENGINE", ""),
 )
 
+SERPAPI_API_KEY = PersistentConfig(
+    "SERPAPI_API_KEY",
+    "rag.web.search.serpapi_api_key",
+    os.getenv("SERPAPI_API_KEY", ""),
+)
+
+SERPAPI_ENGINE = PersistentConfig(
+    "SERPAPI_ENGINE",
+    "rag.web.search.serpapi_engine",
+    os.getenv("SERPAPI_ENGINE", ""),
+)
+
 BING_SEARCH_V7_ENDPOINT = PersistentConfig(
     "BING_SEARCH_V7_ENDPOINT",
     "rag.web.search.bing_search_v7_endpoint",
@@ -1841,6 +1859,12 @@ RAG_WEB_LOADER = PersistentConfig(
     os.environ.get("RAG_WEB_LOADER", "safe_web")
 )
 
+RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
+    "RAG_WEB_SEARCH_TRUST_ENV",
+    "rag.web.search.trust_env",
+    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False),
+)
+
 PLAYWRIGHT_WS_URI = PersistentConfig(
     "PLAYWRIGHT_WS_URI",
     "rag.web.loader.playwright.ws.uri",
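For context, a minimal sketch (illustrative, not part of this commit) of how the new `PersistentConfig` entries are driven from the environment; the values below are placeholders. Note that `os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False)` returns a string whenever the variable is set, so any non-empty value is truthy.

```python
# Illustrative only: env values read by the new PersistentConfig entries.
import os

os.environ["ENABLE_TITLE_GENERATION"] = "true"            # task.title.enable
os.environ["SERPAPI_API_KEY"] = "<your-serpapi-key>"      # rag.web.search.serpapi_api_key
os.environ["SERPAPI_ENGINE"] = "google"                   # rag.web.search.serpapi_engine
os.environ["RAG_WEB_SEARCH_TRUST_ENV"] = "true"           # rag.web.search.trust_env
os.environ["PLAYWRIGHT_WS_URI"] = "ws://playwright:3000"  # placeholder endpoint

# Each PersistentConfig(name, dotted_path, default) seeds itself from the
# env var once and can then be overridden via the persisted config store.
```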
backend/open_webui/main.py
@@ -177,10 +177,13 @@ from open_webui.config import (
     RAG_WEB_SEARCH_ENGINE,
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+    RAG_WEB_SEARCH_TRUST_ENV,
     RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
     JINA_API_KEY,
     SEARCHAPI_API_KEY,
     SEARCHAPI_ENGINE,
+    SERPAPI_API_KEY,
+    SERPAPI_ENGINE,
     SEARXNG_QUERY_URL,
     SERPER_API_KEY,
     SERPLY_API_KEY,
@@ -266,6 +269,7 @@ from open_webui.config import (
     TASK_MODEL,
     TASK_MODEL_EXTERNAL,
     ENABLE_TAGS_GENERATION,
+    ENABLE_TITLE_GENERATION,
     ENABLE_SEARCH_QUERY_GENERATION,
     ENABLE_RETRIEVAL_QUERY_GENERATION,
     ENABLE_AUTOCOMPLETE_GENERATION,
@@ -347,12 +351,12 @@ class SPAStaticFiles(StaticFiles):
 
 print(
     rf"""
-  ___                    __        __   _     _   _ ___
- / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
-| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || |
-| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || |
- \___/| .__/ \___|_| |_|    \_/\_/ \___|_.__/ \___/|___|
-      |_|
+██████╗ ██████╗ ███████╗███╗   ██╗    ██╗    ██╗███████╗██████╗ ██╗   ██╗██╗
+██╔═══██╗██╔══██╗██╔════╝████╗  ██║    ██║    ██║██╔════╝██╔══██╗██║   ██║██║
+██║   ██║██████╔╝█████╗  ██╔██╗ ██║    ██║ █╗ ██║█████╗  ██████╔╝██║   ██║██║
+██║   ██║██╔═══╝ ██╔══╝  ██║╚██╗██║    ██║███╗██║██╔══╝  ██╔══██╗██║   ██║██║
+╚██████╔╝██║     ███████╗██║ ╚████║    ╚███╔███╔╝███████╗██████╔╝╚██████╔╝██║
+ ╚═════╝ ╚═╝     ╚══════╝╚═╝  ╚═══╝     ╚══╝╚══╝ ╚══════╝╚═════╝  ╚═════╝ ╚═╝
 
 
 v{VERSION} - building the best open-source AI user interface.
@@ -548,6 +552,8 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
 app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
 app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
 app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
+app.state.config.SERPAPI_API_KEY = SERPAPI_API_KEY
+app.state.config.SERPAPI_ENGINE = SERPAPI_ENGINE
 app.state.config.JINA_API_KEY = JINA_API_KEY
 app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
 app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
@@ -556,6 +562,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
 app.state.config.RAG_WEB_LOADER = RAG_WEB_LOADER
+app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
 app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
 
 app.state.EMBEDDING_FUNCTION = None
@@ -693,6 +700,7 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
 app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
 app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
 app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
+app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
 
 
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -904,20 +912,30 @@ async def chat_completion(
     if not request.app.state.MODELS:
         await get_all_models(request)
 
     model_item = form_data.pop("model_item", {})
     tasks = form_data.pop("background_tasks", None)
-    try:
-        model_id = form_data.get("model", None)
-        if model_id not in request.app.state.MODELS:
-            raise Exception("Model not found")
-        model = request.app.state.MODELS[model_id]
-        model_info = Models.get_model_by_id(model_id)
-
-        # Check if user has access to the model
-        if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
-            try:
-                check_model_access(user, model)
-            except Exception as e:
-                raise e
+    try:
+        if not model_item.get("direct", False):
+            model_id = form_data.get("model", None)
+            if model_id not in request.app.state.MODELS:
+                raise Exception("Model not found")
+
+            model = request.app.state.MODELS[model_id]
+            model_info = Models.get_model_by_id(model_id)
+
+            # Check if user has access to the model
+            if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
+                try:
+                    check_model_access(user, model)
+                except Exception as e:
+                    raise e
+        else:
+            model = model_item
+            model_info = None
+
+            request.state.direct = True
+            request.state.model = model
 
     metadata = {
         "user_id": user.id,
@@ -929,6 +947,7 @@ async def chat_completion(
         "features": form_data.get("features", None),
         "variables": form_data.get("variables", None),
         "model": model_info,
+        "direct": model_item.get("direct", False),
         **(
             {"function_calling": "native"}
             if form_data.get("params", {}).get("function_calling") == "native"
@@ -940,6 +959,8 @@ async def chat_completion(
             else {}
         ),
     }
+
+    request.state.metadata = metadata
     form_data["metadata"] = metadata
 
     form_data, metadata, events = await process_chat_payload(
@@ -947,6 +968,7 @@ async def chat_completion(
     )
 
     except Exception as e:
+        log.debug(f"Error processing chat payload: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
@@ -975,6 +997,12 @@ async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_completed_handler(request, form_data, user)
     except Exception as e:
         raise HTTPException(
@@ -988,6 +1016,12 @@ async def chat_action(
     request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_action_handler(request, action_id, form_data, user)
     except Exception as e:
         raise HTTPException(
backend/open_webui/retrieval/web/bocha.py
@@ -9,6 +9,7 @@ from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
+
 def _parse_response(response):
     result = {}
     if "data" in response:
@@ -25,7 +26,8 @@ def _parse_response(response):
             "summary": item.get("summary", ""),
             "siteName": item.get("siteName", ""),
             "siteIcon": item.get("siteIcon", ""),
-            "datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""),
+            "datePublished": item.get("datePublished", "")
+            or item.get("dateLastCrawled", ""),
         }
         for item in webPages["value"]
     ]
@@ -42,17 +44,11 @@ def search_bocha(
         query (str): The query to search for
     """
     url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
-    headers = {
-        "Authorization": f"Bearer {api_key}",
-        "Content-Type": "application/json"
-    }
-
-    payload = json.dumps({
-        "query": query,
-        "summary": True,
-        "freshness": "noLimit",
-        "count": count
-    })
+    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+    payload = json.dumps(
+        {"query": query, "summary": True, "freshness": "noLimit", "count": count}
+    )
 
     response = requests.post(url, headers=headers, data=payload, timeout=5)
     response.raise_for_status()
@@ -63,10 +59,7 @@ def search_bocha(
 
     return [
         SearchResult(
-            link=result["url"],
-            title=result.get("name"),
-            snippet=result.get("summary")
+            link=result["url"], title=result.get("name"), snippet=result.get("summary")
         )
         for result in results.get("webpage", [])[:count]
     ]
backend/open_webui/retrieval/web/google_pse.py
@@ -8,6 +8,7 @@ from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
+
 def search_google_pse(
     api_key: str,
     search_engine_id: str,
@@ -46,12 +47,14 @@ def search_google_pse(
         response.raise_for_status()
         json_response = response.json()
         results = json_response.get("items", [])
-        if results: # check if results are returned. If not, no more pages to fetch.
+        if results:  # check if results are returned. If not, no more pages to fetch.
             all_results.extend(results)
-            count -= len(results) # Decrement count by the number of results fetched in this page.
-            start_index += 10 # Increment start index for the next page
+            count -= len(
+                results
+            )  # Decrement count by the number of results fetched in this page.
+            start_index += 10  # Increment start index for the next page
         else:
-            break # No more results from Google PSE, break the loop
+            break  # No more results from Google PSE, break the loop
 
     if filter_list:
         all_results = get_filtered_results(all_results, filter_list)
backend/open_webui/retrieval/web/serpapi.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+import logging
+from typing import Optional
+from urllib.parse import urlencode
+
+import requests
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_serpapi(
+    api_key: str,
+    engine: str,
+    query: str,
+    count: int,
+    filter_list: Optional[list[str]] = None,
+) -> list[SearchResult]:
+    """Search using serpapi.com's API and return the results as a list of SearchResult objects.
+
+    Args:
+        api_key (str): A serpapi.com API key
+        query (str): The query to search for
+    """
+    url = "https://serpapi.com/search"
+
+    engine = engine or "google"
+
+    payload = {"engine": engine, "q": query, "api_key": api_key}
+
+    url = f"{url}?{urlencode(payload)}"
+    response = requests.request("GET", url)
+
+    json_response = response.json()
+    log.info(f"results from serpapi search: {json_response}")
+
+    results = sorted(
+        json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
+    )
+    if filter_list:
+        results = get_filtered_results(results, filter_list)
+    return [
+        SearchResult(
+            link=result["link"], title=result["title"], snippet=result["snippet"]
+        )
+        for result in results[:count]
+    ]
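A hedged usage sketch (illustrative, not part of the commit; assumes it runs where the backend package imports resolve). Note the helper itself falls back to the `google` engine when `engine` is empty:

```python
# Illustrative only: calling the new SerpApi helper directly.
from open_webui.retrieval.web.serpapi import search_serpapi

results = search_serpapi(
    api_key="<your-serpapi-key>",  # placeholder
    engine="",                     # empty -> helper defaults to "google"
    query="open webui",
    count=3,
    filter_list=None,              # optional domain filter list
)
for r in results:
    print(r.link, r.title)
```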
backend/open_webui/retrieval/web/tavily.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Optional
 
 import requests
 from open_webui.retrieval.web.main import SearchResult
@@ -8,7 +9,13 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_tavily(
+    api_key: str,
+    query: str,
+    count: int,
+    filter_list: Optional[list[str]] = None,
+    # **kwargs,
+) -> list[SearchResult]:
     """Search using Tavily's Search API and return the results as a list of SearchResult objects.
 
     Args:
@@ -20,8 +27,8 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
     """
     url = "https://api.tavily.com/search"
     data = {"query": query, "api_key": api_key}
-
-    response = requests.post(url, json=data)
+    include_domain = filter_list
+    response = requests.post(url, include_domain, json=data)
     response.raise_for_status()
 
     json_response = response.json()
backend/open_webui/retrieval/web/utils.py
@@ -2,11 +2,15 @@ import asyncio
 from datetime import datetime, time, timedelta
 import socket
 import ssl
 import aiohttp
 import asyncio
+import urllib.parse
+import certifi
 import validators
 from collections import defaultdict
-from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
+from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
 
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -230,6 +234,71 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def __init__(self, trust_env: bool = False, *args, **kwargs):
+        """Initialize SafeWebBaseLoader
+        Args:
+            trust_env (bool, optional): set to True if using proxy to make web requests, for example
+                using http(s)_proxy environment variables. Defaults to False.
+        """
+        super().__init__(*args, **kwargs)
+        self.trust_env = trust_env
+
+    async def _fetch(
+        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+    ) -> str:
+        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
+            for i in range(retries):
+                try:
+                    kwargs: Dict = dict(
+                        headers=self.session.headers,
+                        cookies=self.session.cookies.get_dict(),
+                    )
+                    if not self.session.verify:
+                        kwargs["ssl"] = False
+
+                    async with session.get(
+                        url, **(self.requests_kwargs | kwargs)
+                    ) as response:
+                        if self.raise_for_status:
+                            response.raise_for_status()
+                        return await response.text()
+                except aiohttp.ClientConnectionError as e:
+                    if i == retries - 1:
+                        raise
+                    else:
+                        log.warning(
+                            f"Error fetching {url} with attempt "
+                            f"{i + 1}/{retries}: {e}. Retrying..."
+                        )
+                        await asyncio.sleep(cooldown * backoff**i)
+        raise ValueError("retry count exceeded")
+
+    def _unpack_fetch_results(
+        self, results: Any, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Unpack fetch results into BeautifulSoup objects."""
+        from bs4 import BeautifulSoup
+
+        final_results = []
+        for i, result in enumerate(results):
+            url = urls[i]
+            if parser is None:
+                if url.endswith(".xml"):
+                    parser = "xml"
+                else:
+                    parser = self.default_parser
+                self._check_parser(parser)
+            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
+        return final_results
+
+    async def ascrape_all(
+        self, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Async fetch all urls, then return soups for all results."""
+        results = await self.fetch_all(urls)
+        return self._unpack_fetch_results(results, urls, parser=parser)
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
@@ -245,6 +314,26 @@ class SafeWebBaseLoader(WebBaseLoader):
                 # Log the error and continue with the next URL
                 log.exception(e, "Error loading %s", path)
 
     async def alazy_load(self) -> AsyncIterator[Document]:
         """Async lazy load text from the url(s) in web_path."""
         results = await self.ascrape_all(self.web_paths)
         for path, soup in zip(self.web_paths, results):
             text = soup.get_text(**self.bs_get_text_kwargs)
             metadata = {"source": path}
             if title := soup.find("title"):
                 metadata["title"] = title.get_text()
             if description := soup.find("meta", attrs={"name": "description"}):
                 metadata["description"] = description.get(
                     "content", "No description found."
                 )
             if html := soup.find("html"):
                 metadata["language"] = html.get("lang", "No language found.")
             yield Document(page_content=text, metadata=metadata)
 
+    async def aload(self) -> list[Document]:
+        """Load data into Document objects."""
+        return [document async for document in self.alazy_load()]
+
+
+RAG_WEB_LOADERS = defaultdict(lambda: SafeWebBaseLoader)
+RAG_WEB_LOADERS["playwright"] = SafePlaywrightURLLoader
+RAG_WEB_LOADERS["safe_web"] = SafeWebBaseLoader
@@ -253,16 +342,19 @@ def get_web_loader(
     urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     requests_per_second: int = 2,
+    trust_env: bool = False,
 ):
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
     web_loader_args = {
-        web_path=safe_urls,
+        "urls": safe_urls,
         "verify_ssl": verify_ssl,
         "requests_per_second": requests_per_second,
-        "continue_on_failure": True
+        "continue_on_failure": True,
+        "trust_env": trust_env,
     }
 
     if PLAYWRIGHT_WS_URI.value:
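A small usage sketch (illustrative, not part of the diff; assumes it runs inside the backend where these imports resolve), showing the new `trust_env` flag and the async `aload()` path:

```python
# Illustrative only: fetch pages through the trust_env-aware loader.
import asyncio
from open_webui.retrieval.web.utils import get_web_loader

async def main():
    loader = get_web_loader(
        ["https://example.com"],
        verify_ssl=True,
        requests_per_second=2,
        trust_env=True,  # honor http(s)_proxy environment variables
    )
    docs = await loader.aload()  # collects Documents via alazy_load()
    for doc in docs:
        print(doc.metadata.get("title"), len(doc.page_content))

asyncio.run(main())
```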
backend/open_webui/routers/retrieval.py
@@ -21,6 +21,7 @@ from fastapi import (
     APIRouter,
 )
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel
 import tiktoken
 
@@ -50,6 +51,7 @@ from open_webui.retrieval.web.duckduckgo import search_duckduckgo
 from open_webui.retrieval.web.google_pse import search_google_pse
 from open_webui.retrieval.web.jina_search import search_jina
 from open_webui.retrieval.web.searchapi import search_searchapi
+from open_webui.retrieval.web.serpapi import search_serpapi
 from open_webui.retrieval.web.searxng import search_searxng
 from open_webui.retrieval.web.serper import search_serper
 from open_webui.retrieval.web.serply import search_serply
@@ -388,6 +390,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
             "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
             "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
             "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
+            "serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
+            "serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
             "jina_api_key": request.app.state.config.JINA_API_KEY,
             "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
             "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
@@ -439,12 +443,15 @@ class WebSearchConfig(BaseModel):
     tavily_api_key: Optional[str] = None
     searchapi_api_key: Optional[str] = None
     searchapi_engine: Optional[str] = None
+    serpapi_api_key: Optional[str] = None
+    serpapi_engine: Optional[str] = None
     jina_api_key: Optional[str] = None
     bing_search_v7_endpoint: Optional[str] = None
     bing_search_v7_subscription_key: Optional[str] = None
     exa_api_key: Optional[str] = None
     result_count: Optional[int] = None
     concurrent_requests: Optional[int] = None
+    trust_env: Optional[bool] = None
    domain_filter_list: Optional[List[str]] = []
 
 
@@ -545,6 +552,9 @@ async def update_rag_config(
         form_data.web.search.searchapi_engine
     )
 
+    request.app.state.config.SERPAPI_API_KEY = form_data.web.search.serpapi_api_key
+    request.app.state.config.SERPAPI_ENGINE = form_data.web.search.serpapi_engine
+
     request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
     request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
         form_data.web.search.bing_search_v7_endpoint
@@ -561,6 +571,9 @@ async def update_rag_config(
     request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
         form_data.web.search.concurrent_requests
     )
+    request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
+        form_data.web.search.trust_env
+    )
     request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
         form_data.web.search.domain_filter_list
     )
@@ -604,6 +617,8 @@ async def update_rag_config(
             "serply_api_key": request.app.state.config.SERPLY_API_KEY,
             "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
             "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
+            "serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
+            "serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
             "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
             "jina_api_key": request.app.state.config.JINA_API_KEY,
             "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
@@ -611,6 +626,7 @@ async def update_rag_config(
             "exa_api_key": request.app.state.config.EXA_API_KEY,
             "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
             "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
             "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
         },
     },
@@ -760,7 +776,11 @@ def save_docs_to_vector_db(
     # for meta-data so convert them to string.
     for metadata in metadatas:
         for key, value in metadata.items():
-            if isinstance(value, datetime):
+            if (
+                isinstance(value, datetime)
+                or isinstance(value, list)
+                or isinstance(value, dict)
+            ):
                 metadata[key] = str(value)
 
     try:
@@ -1127,6 +1147,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
     - TAVILY_API_KEY
     - EXA_API_KEY
     - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
+    - SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
 
     Args:
         query (str): The query to search for
     """
@@ -1255,6 +1276,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
             )
         else:
             raise Exception("No SEARCHAPI_API_KEY found in environment variables")
+    elif engine == "serpapi":
+        if request.app.state.config.SERPAPI_API_KEY:
+            return search_serpapi(
+                request.app.state.config.SERPAPI_API_KEY,
+                request.app.state.config.SERPAPI_ENGINE,
+                query,
+                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+            )
+        else:
+            raise Exception("No SERPAPI_API_KEY found in environment variables")
     elif engine == "jina":
         return search_jina(
             request.app.state.config.JINA_API_KEY,
@@ -1314,17 +1346,25 @@ async def process_web_search(
             urls,
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
         )
-        docs = [doc async for doc in loader.alazy_load()]
-        # docs = loader.load()
-        save_docs_to_vector_db(
-            request, docs, collection_name, overwrite=True, user=user
-        )
+        docs = await loader.aload()
+        await run_in_threadpool(
+            save_docs_to_vector_db,
+            request,
+            docs,
+            collection_name,
+            overwrite=True,
+            user=user
+        )
 
         return {
             "status": True,
             "collection_name": collection_name,
             "filenames": urls,
             "loaded_count": len(docs),
         }
     except Exception as e:
        log.exception(e)
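For context, a hedged sketch of the offloading pattern this hunk adopts: the blocking vector-DB write is pushed to a worker thread so the event loop keeps serving requests. `blocking_save` is a hypothetical stand-in for `save_docs_to_vector_db`:

```python
# Illustrative only: FastAPI/Starlette's run_in_threadpool runs a blocking
# callable in a thread so it doesn't stall the asyncio event loop.
import asyncio
from fastapi.concurrency import run_in_threadpool

def blocking_save(docs: list[str]) -> int:
    # stand-in for save_docs_to_vector_db: synchronous, potentially slow
    return len(docs)

async def handler() -> None:
    saved = await run_in_threadpool(blocking_save, ["doc-a", "doc-b"])
    print(f"saved {saved} docs without blocking the loop")

asyncio.run(handler())
```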
backend/open_webui/routers/tasks.py
@@ -58,6 +58,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
+        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
         "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
         "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
         "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -68,6 +69,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
 class TaskConfigForm(BaseModel):
     TASK_MODEL: Optional[str]
     TASK_MODEL_EXTERNAL: Optional[str]
+    ENABLE_TITLE_GENERATION: bool
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
     ENABLE_AUTOCOMPLETE_GENERATION: bool
@@ -86,6 +88,7 @@ async def update_task_config(
 ):
     request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
     request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
+    request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
     request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
         form_data.TITLE_GENERATION_PROMPT_TEMPLATE
     )
@@ -122,6 +125,7 @@ async def update_task_config(
     return {
         "TASK_MODEL": request.app.state.config.TASK_MODEL,
         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
+        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
@@ -139,7 +143,19 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if not request.app.state.config.ENABLE_TITLE_GENERATION:
+        return JSONResponse(
+            status_code=status.HTTP_200_OK,
+            content={"detail": "Title generation is disabled"},
+        )
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -198,6 +214,7 @@ async def generate_title(
             }
         ),
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -225,7 +242,12 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )
 
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -261,6 +283,7 @@ async def generate_chat_tags(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.TAGS_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -281,7 +304,12 @@ async def generate_image_prompt(
 async def generate_image_prompt(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -321,6 +349,7 @@ async def generate_image_prompt(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.IMAGE_PROMPT_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -356,7 +385,12 @@ async def generate_queries(
             detail=f"Query generation is disabled",
         )
 
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -392,6 +426,7 @@ async def generate_queries(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.QUERY_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -431,7 +466,12 @@ async def generate_autocompletion(
             detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
         )
 
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -467,6 +507,7 @@ async def generate_autocompletion(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "task": str(TASKS.AUTOCOMPLETE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -488,7 +529,12 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -531,7 +577,11 @@ async def generate_emoji(
             }
         ),
         "chat_id": form_data.get("chat_id", None),
-        "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
+        "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+            "task": str(TASKS.EMOJI_GENERATION),
+            "task_body": form_data,
+        },
     }
 
     try:
@@ -548,7 +598,13 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
 
     if model_id not in models:
@@ -581,6 +637,7 @@ async def generate_moa_response(
         "messages": [{"role": "user", "content": content}],
         "stream": form_data.get("stream", False),
         "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
             "chat_id": form_data.get("chat_id", None),
             "task": str(TASKS.MOA_RESPONSE_GENERATION),
             "task_body": form_data,
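The same fallback repeats in every task endpoint above; distilled into one sketch (illustrative, names as in the diff, `resolve_models` is a hypothetical helper):

```python
# Illustrative only: resolve the model table, preferring the per-request
# "direct" model stashed on request.state by chat_completion.
def resolve_models(request) -> dict:
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        # Direct connection: only the single model supplied by the client.
        return {request.state.model["id"]: request.state.model}
    # Default: the server-wide model registry.
    return request.app.state.MODELS
```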
backend/open_webui/socket/main.py
@@ -279,8 +279,8 @@ def get_event_emitter(request_info):
             await sio.emit(
                 "chat-events",
                 {
-                    "chat_id": request_info["chat_id"],
-                    "message_id": request_info["message_id"],
+                    "chat_id": request_info.get("chat_id", None),
+                    "message_id": request_info.get("message_id", None),
                     "data": event_data,
                 },
                 to=session_id,
@@ -329,8 +329,8 @@ def get_event_call(request_info):
             response = await sio.call(
                 "chat-events",
                 {
-                    "chat_id": request_info["chat_id"],
-                    "message_id": request_info["message_id"],
+                    "chat_id": request_info.get("chat_id", None),
+                    "message_id": request_info.get("message_id", None),
                     "data": event_data,
                 },
                 to=request_info["session_id"],
backend/open_webui/utils/chat.py
@@ -7,14 +7,17 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import uuid
+import asyncio
 
-from fastapi import Request
-from starlette.responses import Response, StreamingResponse
+from fastapi import Request, status
+from starlette.responses import Response, StreamingResponse, JSONResponse
 
 
 from open_webui.models.users import UserModel
 
 from open_webui.socket.main import (
     sio,
     get_event_call,
     get_event_emitter,
 )
@@ -57,16 +60,127 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    print("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        async def message_listener(sid, data):
+            """
+            Handle received socket messages and push them into the queue.
+            """
+            await q.put(data)
+
+        # Register the listener
+        sio.on(channel, message_listener)
+
+        # Start processing chat completion in background
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        print("res", res)
+
+        if res.get("status", False):
+            # Define a generator to stream responses
+            async def event_generator():
+                nonlocal q
+                try:
+                    while True:
+                        data = await q.get()  # Wait for new messages
+                        if isinstance(data, dict):
+                            if "done" in data and data["done"]:
+                                break  # Stop streaming when 'done' is received
+
+                            yield f"data: {json.dumps(data)}\n\n"
+                        elif isinstance(data, str):
+                            yield data
+                except Exception as e:
+                    log.debug(f"Error in event generator: {e}")
+                    pass
+
+            # Define a background task to run the event generator
+            async def background():
+                try:
+                    del sio.handlers["/"][channel]
+                except Exception as e:
+                    pass
+
+            # Return the streaming response
+            return StreamingResponse(
+                event_generator(), media_type="text/event-stream", background=background
+            )
+        else:
+            raise Exception(str(res))
+    else:
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        if "error" in res:
+            raise Exception(res["error"])
+
+        return res
+
+
 async def generate_chat_completion(
     request: Request,
     form_data: dict,
     user: Any,
     bypass_filter: bool = False,
 ):
     log.debug(f"generate_chat_completion: {form_data}")
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
 
-    models = request.app.state.MODELS
+    if hasattr(request.state, "metadata"):
+        if "metadata" not in form_data:
+            form_data["metadata"] = request.state.metadata
+        else:
+            form_data["metadata"] = {
+                **form_data["metadata"],
+                **request.state.metadata,
+            }
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+        log.debug(f"direct connection to model: {models}")
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -80,85 +194,96 @@ async def generate_chat_completion(
 
     model = models[model_id]
 
+    if getattr(request.state, "direct", False):
+        return await generate_direct_chat_completion(
+            request, form_data, user=user, models=models
+        )
+
     # Check if user has access to the model
     if not bypass_filter and user.role == "user":
         try:
             check_model_access(user, model)
         except Exception as e:
             raise e
 
     if model["owned_by"] == "arena":
         model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
         filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
         if model_ids and filter_mode == "exclude":
             model_ids = [
                 model["id"]
                 for model in list(request.app.state.MODELS.values())
                 if model.get("owned_by") != "arena" and model["id"] not in model_ids
             ]
 
         selected_model_id = None
         if isinstance(model_ids, list) and model_ids:
             selected_model_id = random.choice(model_ids)
         else:
             model_ids = [
                 model["id"]
                 for model in list(request.app.state.MODELS.values())
                 if model.get("owned_by") != "arena"
             ]
             selected_model_id = random.choice(model_ids)
 
         form_data["model"] = selected_model_id
 
         if form_data.get("stream") == True:
 
             async def stream_wrapper(stream):
                 yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                 async for chunk in stream:
                     yield chunk
 
             response = await generate_chat_completion(
                 request, form_data, user, bypass_filter=True
             )
             return StreamingResponse(
                 stream_wrapper(response.body_iterator),
                 media_type="text/event-stream",
                 background=response.background,
             )
         else:
             return {
                 **(
                     await generate_chat_completion(
                         request, form_data, user, bypass_filter=True
                     )
                 ),
                 "selected_model_id": selected_model_id,
             }
 
     if model.get("pipe"):
         # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
         return await generate_function_chat_completion(
             request, form_data, user=user, models=models
         )
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
-        response = await generate_ollama_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
+        response = await generate_ollama_chat_completion(
+            request=request,
+            form_data=form_data,
+            user=user,
+            bypass_filter=bypass_filter,
+        )
         if form_data.get("stream"):
             response.headers["content-type"] = "text/event-stream"
             return StreamingResponse(
                 convert_streaming_response_ollama_to_openai(response),
                 headers=dict(response.headers),
                 background=response.background,
             )
         else:
             return convert_response_ollama_to_openai(response)
     else:
-        return await generate_openai_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
+        return await generate_openai_chat_completion(
+            request=request,
+            form_data=form_data,
+            user=user,
+            bypass_filter=bypass_filter,
+        )
 
 
 chat_completion = generate_chat_completion
@@ -167,7 +292,13 @@ chat_completion = generate_chat_completion
 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     data = form_data
     model_id = data["model"]
@@ -227,7 +358,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
 
     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     data = form_data
     model_id = data["model"]
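A hedged sketch of the queue-to-SSE bridge that `generate_direct_chat_completion` builds: a socket listener feeds an `asyncio.Queue`, and a generator drains it into a `text/event-stream` body until a `{"done": True}` sentinel arrives. This is a standalone re-implementation of the idea, not the function itself:

```python
# Illustrative only: drain an asyncio.Queue into SSE-formatted chunks.
import asyncio
import json
from typing import AsyncIterator

async def sse_from_queue(q: asyncio.Queue) -> AsyncIterator[str]:
    while True:
        data = await q.get()
        if isinstance(data, dict):
            if data.get("done"):
                break  # sentinel: end of stream
            yield f"data: {json.dumps(data)}\n\n"
        elif isinstance(data, str):
            yield data  # pre-formatted SSE chunk passes through

async def demo() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for item in ({"delta": "Hel"}, {"delta": "lo"}, {"done": True}):
        q.put_nowait(item)
    async for chunk in sse_from_queue(q):
        print(chunk, end="")

asyncio.run(demo())
```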
backend/open_webui/utils/middleware.py
@@ -616,7 +616,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     task_model_id = get_task_model_id(
         form_data["model"],
         request.app.state.config.TASK_MODEL,
@@ -766,17 +772,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
             if "document" in source:
                 for doc_idx, doc_context in enumerate(source["document"]):
-                    doc_metadata = source.get("metadata")
-                    doc_source_id = None
-
-                    if doc_metadata:
-                        doc_source_id = doc_metadata[doc_idx].get("source", source_id)
-
-                    if source_id:
-                        context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
-                    else:
-                        # If there is no source_id, then do not include the source_id tag
-                        context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
+                    context_string += f"<source><source_id>{doc_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
 
         context_string = context_string.strip()
         prompt = get_last_user_message(form_data["messages"])
@@ -1149,6 +1145,46 @@ async def process_chat_response(
 
         return content.strip()
 
+    def convert_content_blocks_to_messages(content_blocks):
+        messages = []
+
+        temp_blocks = []
+        for idx, block in enumerate(content_blocks):
+            if block["type"] == "tool_calls":
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": serialize_content_blocks(temp_blocks),
+                        "tool_calls": block.get("content"),
+                    }
+                )
+
+                results = block.get("results", [])
+
+                for result in results:
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": result["tool_call_id"],
+                            "content": result["content"],
+                        }
+                    )
+                temp_blocks = []
+            else:
+                temp_blocks.append(block)
+
+        if temp_blocks:
+            content = serialize_content_blocks(temp_blocks)
+            if content:
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": content,
+                    }
+                )
+
+        return messages
+
     def tag_content_handler(content_type, tags, content, content_blocks):
         end_flag = False
 
@@ -1540,7 +1576,6 @@ async def process_chat_response(
 
                         results = []
                         for tool_call in response_tool_calls:
-                            print("\n\n" + str(tool_call) + "\n\n")
                             tool_call_id = tool_call.get("id", "")
                             tool_name = tool_call.get("function", {}).get("name", "")
 
@@ -1606,23 +1641,10 @@ async def process_chat_response(
                             {
                                 "model": model_id,
                                 "stream": True,
                                 "tools": form_data["tools"],
                                 "messages": [
                                     *form_data["messages"],
-                                    {
-                                        "role": "assistant",
-                                        "content": serialize_content_blocks(
-                                            content_blocks, raw=True
-                                        ),
-                                        "tool_calls": response_tool_calls,
-                                    },
-                                    *[
-                                        {
-                                            "role": "tool",
-                                            "tool_call_id": result["tool_call_id"],
-                                            "content": result["content"],
-                                        }
-                                        for result in results
-                                    ],
+                                    *convert_content_blocks_to_messages(content_blocks),
                                 ],
                             },
                             user,
@@ -1671,6 +1693,9 @@ async def process_chat_response(
                                     "data": {
                                         "id": str(uuid4()),
                                         "code": code,
+                                        "session_id": metadata.get(
+                                            "session_id", None
+                                        ),
                                     },
                                 }
                             )
@@ -1699,10 +1724,12 @@ async def process_chat_response(
                                     "stdout": "Code interpreter engine not configured."
                                 }
 
+                            log.debug(f"Code interpreter output: {output}")
+
                             if isinstance(output, dict):
                                 stdout = output.get("stdout", "")
 
                                 if stdout:
+                                    if isinstance(stdout, str):
                                         stdoutLines = stdout.split("\n")
                                         for idx, line in enumerate(stdoutLines):
                                             if "data:image/png;base64" in line:
@@ -1734,7 +1761,7 @@ async def process_chat_response(
 
                                 result = output.get("result", "")
 
                                 if result:
+                                    if isinstance(result, str):
                                         resultLines = result.split("\n")
                                         for idx, line in enumerate(resultLines):
                                             if "data:image/png;base64" in line:
@@ -1784,6 +1811,8 @@ async def process_chat_response(
                                 }
                             )
 
+                            print(content_blocks, serialize_content_blocks(content_blocks))
+
                             try:
                                 res = await generate_chat_completion(
                                     request,
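A hedged sketch of what the new `convert_content_blocks_to_messages` helper produces for a simple tool-call exchange; the block shapes are inferred from the diff, and `serialize_content_blocks` is stubbed:

```python
# Illustrative only: expected input/output shape of the conversion.
def serialize_content_blocks(blocks):
    # trivial stand-in for the real serializer
    return "".join(b.get("content", "") for b in blocks)

content_blocks = [
    {"type": "text", "content": "Let me check the weather."},
    {
        "type": "tool_calls",
        "content": [{"id": "call_1", "function": {"name": "get_weather"}}],
        "results": [{"tool_call_id": "call_1", "content": '{"temp": 21}'}],
    },
    {"type": "text", "content": "It is 21 degrees."},
]

# convert_content_blocks_to_messages(content_blocks) would yield:
# [
#   {"role": "assistant", "content": "Let me check the weather.",
#    "tool_calls": [{"id": "call_1", ...}]},
#   {"role": "tool", "tool_call_id": "call_1", "content": '{"temp": 21}'},
#   {"role": "assistant", "content": "It is 21 degrees."},
# ]
```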
backend/open_webui/utils/misc.py
@@ -217,12 +217,20 @@ def openai_chat_chunk_message_template(
 
 
 def openai_chat_completion_message_template(
-    model: str, message: Optional[str] = None, usage: Optional[dict] = None
+    model: str,
+    message: Optional[str] = None,
+    tool_calls: Optional[list[dict]] = None,
+    usage: Optional[dict] = None,
 ) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion"
     if message is not None:
-        template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+        template["choices"][0]["message"] = {
+            "content": message,
+            "role": "assistant",
+            **({"tool_calls": tool_calls} if tool_calls else {}),
+        }
 
     template["choices"][0]["finish_reason"] = "stop"
 
     if usage:
backend/open_webui/utils/response.py
@@ -6,9 +6,32 @@ from open_webui.utils.misc import (
 )
 
 
+def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
+    openai_tool_calls = []
+    for tool_call in tool_calls:
+        openai_tool_call = {
+            "index": tool_call.get("index", 0),
+            "id": tool_call.get("id", f"call_{str(uuid4())}"),
+            "type": "function",
+            "function": {
+                "name": tool_call.get("function", {}).get("name", ""),
+                "arguments": json.dumps(
+                    tool_call.get("function", {}).get("arguments", {})
+                ),
+            },
+        }
+        openai_tool_calls.append(openai_tool_call)
+    return openai_tool_calls
+
+
 def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     model = ollama_response.get("model", "ollama")
     message_content = ollama_response.get("message", {}).get("content", "")
+    tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
+    openai_tool_calls = None
+
+    if tool_calls:
+        openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
 
     data = ollama_response
     usage = {
@@ -51,7 +74,9 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
         ),
     }
 
-    response = openai_chat_completion_message_template(model, message_content, usage)
+    response = openai_chat_completion_message_template(
+        model, message_content, openai_tool_calls, usage
+    )
     return response
 
 
@@ -65,20 +90,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
         openai_tool_calls = None
 
         if tool_calls:
-            openai_tool_calls = []
-            for tool_call in tool_calls:
-                openai_tool_call = {
-                    "index": tool_call.get("index", 0),
-                    "id": tool_call.get("id", f"call_{str(uuid4())}"),
-                    "type": "function",
-                    "function": {
-                        "name": tool_call.get("function", {}).get("name", ""),
-                        "arguments": json.dumps(
-                            tool_call.get("function", {}).get("arguments", {})
-                        ),
-                    },
-                }
-                openai_tool_calls.append(openai_tool_call)
+            openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
 
         done = data.get("done", False)
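A hedged, standalone sketch (mirroring the helper above, not the helper itself) of the Ollama-to-OpenAI tool-call shape change — the key difference is that OpenAI expects `arguments` as a JSON string rather than a dict:

```python
# Illustrative only: the shape translation the new helper performs.
import json
from uuid import uuid4

ollama_tool_calls = [
    {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
]

openai_tool_calls = [
    {
        "index": tc.get("index", 0),
        "id": tc.get("id", f"call_{uuid4()}"),
        "type": "function",
        "function": {
            "name": tc.get("function", {}).get("name", ""),
            # OpenAI expects arguments serialized as a JSON string
            "arguments": json.dumps(tc.get("function", {}).get("arguments", {})),
        },
    }
    for tc in ollama_tool_calls
]
print(openai_tool_calls[0]["function"]["arguments"])  # '{"city": "Berlin"}'
```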
backend/requirements.txt
@@ -3,9 +3,6 @@ uvicorn[standard]==0.30.6
 pydantic==2.9.2
 python-multipart==0.0.18
 
-Flask==3.1.0
-Flask-Cors==5.0.0
-
 python-socketio==5.11.3
 python-jose==3.3.0
 passlib[bcrypt]==1.7.4