refac: web search

This commit is contained in:
Timothy Jaeryang Baek
2024-12-24 17:52:57 -07:00
parent a074991d3a
commit 6b25139d4f
5 changed files with 179 additions and 100 deletions

View File

@@ -856,6 +856,7 @@ async def chat_completion(
"session_id": form_data.pop("session_id", None),
"tool_ids": form_data.get("tool_ids", None),
"files": form_data.get("files", None),
"features": form_data.get("features", None),
}
form_data["metadata"] = metadata

View File

@@ -1238,7 +1238,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
@router.post("/process/web/search")
def process_web_search(
async def process_web_search(
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
):
try:
@@ -1256,9 +1256,11 @@ def process_web_search(
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
)
log.debug(f"web_results: {web_results}")
try:
collection_name = form_data.collection_name
if collection_name == "":
if collection_name == "" or collection_name is None:
collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
:63
]
@@ -1269,8 +1271,7 @@ def process_web_search(
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
)
docs = loader.aload()
docs = loader.load()
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
return {

View File

@@ -29,6 +29,7 @@ from open_webui.routers.tasks import (
generate_title,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.utils.webhook import post_webhook
@@ -333,6 +334,149 @@ async def chat_completion_tools_handler(
return body, {"sources": sources}
async def chat_web_search_handler(
request: Request, form_data: dict, extra_params: dict, user
):
event_emitter = extra_params["__event_emitter__"]
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "Generating search query",
"done": False,
},
}
)
messages = form_data["messages"]
user_message = get_last_user_message(messages)
queries = []
try:
res = await generate_queries(
request,
{
"model": form_data["model"],
"messages": messages,
"prompt": user_message,
"type": "web_search",
},
user,
)
response = res["choices"][0]["message"]["content"]
try:
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1
if bracket_start == -1 or bracket_end == -1:
raise Exception("No JSON object found in the response")
response = response[bracket_start:bracket_end]
queries = json.loads(response)
queries = queries.get("queries", [])
except Exception as e:
queries = [response]
except Exception as e:
log.exception(e)
queries = [user_message]
if len(queries) == 0:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search query generated",
"done": True,
},
}
)
return
searchQuery = queries[0]
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": 'Searching "{{searchQuery}}"',
"query": searchQuery,
"done": False,
},
}
)
try:
results = await process_web_search(
request,
SearchForm(
**{
"query": searchQuery,
}
),
user,
)
if results:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "Searched {{count}} sites",
"query": searchQuery,
"urls": results["filenames"],
"done": True,
},
}
)
files = form_data.get("files", [])
files.append(
{
"collection_name": results["collection_name"],
"name": searchQuery,
"type": "web_search_results",
"urls": results["filenames"],
}
)
form_data["files"] = files
else:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search results found",
"query": searchQuery,
"done": True,
"error": True,
},
}
)
except Exception as e:
log.exception(e)
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": 'Error searching "{{searchQuery}}"',
"query": searchQuery,
"done": True,
"error": True,
},
}
)
return form_data
async def chat_completion_files_handler(
request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
@@ -456,7 +600,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
knowledge_files = []
for item in model_knowledge:
print(item)
if item.get("collection_name"):
knowledge_files.append(
{
@@ -481,6 +624,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
files.extend(knowledge_files)
form_data["files"] = files
features = form_data.pop("features", None)
if features:
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
)
try:
form_data, flags = await chat_completion_filter_functions_handler(
request, form_data, model, extra_params