mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 03:47:49 +01:00
refac: web search
This commit is contained in:
@@ -856,6 +856,7 @@ async def chat_completion(
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
"tool_ids": form_data.get("tool_ids", None),
|
||||
"files": form_data.get("files", None),
|
||||
"features": form_data.get("features", None),
|
||||
}
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
|
||||
@@ -1238,7 +1238,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
|
||||
|
||||
@router.post("/process/web/search")
|
||||
def process_web_search(
|
||||
async def process_web_search(
|
||||
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
@@ -1256,9 +1256,11 @@ def process_web_search(
|
||||
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
||||
)
|
||||
|
||||
log.debug(f"web_results: {web_results}")
|
||||
|
||||
try:
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name == "":
|
||||
if collection_name == "" or collection_name is None:
|
||||
collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
|
||||
:63
|
||||
]
|
||||
@@ -1269,8 +1271,7 @@ def process_web_search(
|
||||
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
)
|
||||
docs = loader.aload()
|
||||
|
||||
docs = loader.load()
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
|
||||
return {
|
||||
|
||||
@@ -29,6 +29,7 @@ from open_webui.routers.tasks import (
|
||||
generate_title,
|
||||
generate_chat_tags,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
|
||||
@@ -333,6 +334,149 @@ async def chat_completion_tools_handler(
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def chat_web_search_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
event_emitter = extra_params["__event_emitter__"]
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "Generating search query",
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
messages = form_data["messages"]
|
||||
user_message = get_last_user_message(messages)
|
||||
|
||||
queries = []
|
||||
try:
|
||||
res = await generate_queries(
|
||||
request,
|
||||
{
|
||||
"model": form_data["model"],
|
||||
"messages": messages,
|
||||
"prompt": user_message,
|
||||
"type": "web_search",
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
response = res["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
bracket_start = response.find("{")
|
||||
bracket_end = response.rfind("}") + 1
|
||||
|
||||
if bracket_start == -1 or bracket_end == -1:
|
||||
raise Exception("No JSON object found in the response")
|
||||
|
||||
response = response[bracket_start:bracket_end]
|
||||
queries = json.loads(response)
|
||||
queries = queries.get("queries", [])
|
||||
except Exception as e:
|
||||
queries = [response]
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
queries = [user_message]
|
||||
|
||||
if len(queries) == 0:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search query generated",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
searchQuery = queries[0]
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
results = await process_web_search(
|
||||
request,
|
||||
SearchForm(
|
||||
**{
|
||||
"query": searchQuery,
|
||||
}
|
||||
),
|
||||
user,
|
||||
)
|
||||
|
||||
if results:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "Searched {{count}} sites",
|
||||
"query": searchQuery,
|
||||
"urls": results["filenames"],
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
files = form_data.get("files", [])
|
||||
files.append(
|
||||
{
|
||||
"collection_name": results["collection_name"],
|
||||
"name": searchQuery,
|
||||
"type": "web_search_results",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
form_data["files"] = files
|
||||
else:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search results found",
|
||||
"query": searchQuery,
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Error searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_completion_files_handler(
|
||||
request: Request, body: dict, user: UserModel
|
||||
) -> tuple[dict, dict[str, list]]:
|
||||
@@ -456,7 +600,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
knowledge_files = []
|
||||
for item in model_knowledge:
|
||||
print(item)
|
||||
if item.get("collection_name"):
|
||||
knowledge_files.append(
|
||||
{
|
||||
@@ -481,6 +624,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
files.extend(knowledge_files)
|
||||
form_data["files"] = files
|
||||
|
||||
features = form_data.pop("features", None)
|
||||
if features:
|
||||
if "web_search" in features and features["web_search"]:
|
||||
form_data = await chat_web_search_handler(
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
|
||||
Reference in New Issue
Block a user