Merge remote-tracking branch 'upstream/dev' into playwright

# Conflicts:
#	backend/open_webui/config.py
#	backend/open_webui/main.py
#	backend/open_webui/retrieval/web/utils.py
#	backend/open_webui/routers/retrieval.py
#	backend/open_webui/utils/middleware.py
#	pyproject.toml
This commit is contained in:
Rory
2025-02-14 20:48:22 -06:00
92 changed files with 2583 additions and 454 deletions

View File

@@ -21,6 +21,7 @@ from fastapi import (
APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
import tiktoken
@@ -50,6 +51,7 @@ from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
from open_webui.retrieval.web.searchapi import search_searchapi
from open_webui.retrieval.web.serpapi import search_serpapi
from open_webui.retrieval.web.searxng import search_searxng
from open_webui.retrieval.web.serper import search_serper
from open_webui.retrieval.web.serply import search_serply
@@ -388,6 +390,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
"searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
"jina_api_key": request.app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
@@ -439,12 +443,15 @@ class WebSearchConfig(BaseModel):
tavily_api_key: Optional[str] = None
searchapi_api_key: Optional[str] = None
searchapi_engine: Optional[str] = None
serpapi_api_key: Optional[str] = None
serpapi_engine: Optional[str] = None
jina_api_key: Optional[str] = None
bing_search_v7_endpoint: Optional[str] = None
bing_search_v7_subscription_key: Optional[str] = None
exa_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
trust_env: Optional[bool] = None
domain_filter_list: Optional[List[str]] = []
@@ -545,6 +552,9 @@ async def update_rag_config(
form_data.web.search.searchapi_engine
)
request.app.state.config.SERPAPI_API_KEY = form_data.web.search.serpapi_api_key
request.app.state.config.SERPAPI_ENGINE = form_data.web.search.serpapi_engine
request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
form_data.web.search.bing_search_v7_endpoint
@@ -561,6 +571,9 @@ async def update_rag_config(
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
form_data.web.search.trust_env
)
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
form_data.web.search.domain_filter_list
)
@@ -604,6 +617,8 @@ async def update_rag_config(
"serply_api_key": request.app.state.config.SERPLY_API_KEY,
"serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
"jina_api_key": request.app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
@@ -611,6 +626,7 @@ async def update_rag_config(
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
@@ -760,7 +776,11 @@ def save_docs_to_vector_db(
# for meta-data so convert them to string.
for metadata in metadatas:
for key, value in metadata.items():
if isinstance(value, datetime):
if (
isinstance(value, datetime)
or isinstance(value, list)
or isinstance(value, dict)
):
metadata[key] = str(value)
try:
@@ -1127,6 +1147,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
- TAVILY_API_KEY
- EXA_API_KEY
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
- SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
Args:
query (str): The query to search for
"""
@@ -1255,6 +1276,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
elif engine == "serpapi":
if request.app.state.config.SERPAPI_API_KEY:
return search_serpapi(
request.app.state.config.SERPAPI_API_KEY,
request.app.state.config.SERPAPI_ENGINE,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPAPI_API_KEY found in environment variables")
elif engine == "jina":
return search_jina(
request.app.state.config.JINA_API_KEY,
@@ -1314,17 +1346,25 @@ async def process_web_search(
urls,
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
)
docs = [doc async for doc in loader.alazy_load()]
# docs = loader.load()
save_docs_to_vector_db(
request, docs, collection_name, overwrite=True, user=user
docs = await loader.aload()
await run_in_threadpool(
save_docs_to_vector_db,
request,
docs,
collection_name,
overwrite=True,
user=user
)
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
"loaded_count": len(docs),
}
except Exception as e:
log.exception(e)

View File

@@ -58,6 +58,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -68,6 +69,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
ENABLE_TITLE_GENERATION: bool
TITLE_GENERATION_PROMPT_TEMPLATE: str
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
ENABLE_AUTOCOMPLETE_GENERATION: bool
@@ -86,6 +88,7 @@ async def update_task_config(
):
request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
@@ -122,6 +125,7 @@ async def update_task_config(
return {
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
@@ -139,7 +143,19 @@ async def update_task_config(
async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if not request.app.state.config.ENABLE_TITLE_GENERATION:
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"detail": "Title generation is disabled"},
)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -198,6 +214,7 @@ async def generate_title(
}
),
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.TITLE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -225,7 +242,12 @@ async def generate_chat_tags(
content={"detail": "Tags generation is disabled"},
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -261,6 +283,7 @@ async def generate_chat_tags(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.TAGS_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -281,7 +304,12 @@ async def generate_chat_tags(
async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -321,6 +349,7 @@ async def generate_image_prompt(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -356,7 +385,12 @@ async def generate_queries(
detail=f"Query generation is disabled",
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -392,6 +426,7 @@ async def generate_queries(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.QUERY_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -431,7 +466,12 @@ async def generate_autocompletion(
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -467,6 +507,7 @@ async def generate_autocompletion(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -488,7 +529,12 @@ async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -531,7 +577,11 @@ async def generate_emoji(
}
),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.EMOJI_GENERATION),
"task_body": form_data,
},
}
try:
@@ -548,7 +598,13 @@ async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -581,6 +637,7 @@ async def generate_moa_response(
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.MOA_RESPONSE_GENERATION),
"task_body": form_data,