enh: bypass embedding and retrieval

This commit is contained in:
Timothy Jaeryang Baek
2025-02-26 15:42:19 -08:00
parent 1c2e36f1b7
commit 57010901e6
10 changed files with 486 additions and 370 deletions

View File

@@ -17,6 +17,7 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.env import (
SRC_LOG_LEVELS,
@@ -342,6 +343,7 @@ def get_embedding_function(
def get_sources_from_files(
request,
files,
queries,
embedding_function,
@@ -359,19 +361,64 @@ def get_sources_from_files(
relevant_contexts = []
for file in files:
context = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
context = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
else:
context = None
elif (
file.get("type") != "web_search"
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
# BYPASS_EMBEDDING_AND_RETRIEVAL
if file.get("type") == "collection":
file_ids = file.get("data", {}).get("file_ids", [])
documents = []
metadatas = []
for file_id in file_ids:
file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
context = {
"documents": [documents],
"metadatas": [metadatas],
}
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
context = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": file.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
else:
collection_names = []
if file.get("type") == "collection":
if file.get("legacy"):
@@ -434,6 +481,7 @@ def get_sources_from_files(
if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file})
sources = []