diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py
index 10d6ff9a70..ba53a1c895 100644
--- a/backend/open_webui/apps/retrieval/main.py
+++ b/backend/open_webui/apps/retrieval/main.py
@@ -902,10 +902,11 @@ def process_file(
Document(
page_content=form_data.content,
metadata={
- "name": file.meta.get("name", file.filename),
+ **file.meta,
+ "name": file.filename,
"created_by": file.user_id,
"file_id": file.id,
- **file.meta,
+ "source": file.filename,
},
)
]
@@ -932,10 +933,11 @@ def process_file(
Document(
page_content=file.data.get("content", ""),
metadata={
- "name": file.meta.get("name", file.filename),
+ **file.meta,
+ "name": file.filename,
"created_by": file.user_id,
"file_id": file.id,
- **file.meta,
+ "source": file.filename,
},
)
]
@@ -955,15 +957,30 @@ def process_file(
docs = loader.load(
file.filename, file.meta.get("content_type"), file_path
)
+
+ docs = [
+ Document(
+ page_content=doc.page_content,
+ metadata={
+ **doc.metadata,
+ "name": file.filename,
+ "created_by": file.user_id,
+ "file_id": file.id,
+ "source": file.filename,
+ },
+ )
+ for doc in docs
+ ]
else:
docs = [
Document(
page_content=file.data.get("content", ""),
metadata={
+ **file.meta,
"name": file.filename,
"created_by": file.user_id,
"file_id": file.id,
- **file.meta,
+ "source": file.filename,
},
)
]
diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py
index 6d87c98e36..e4e36fbfdb 100644
--- a/backend/open_webui/apps/retrieval/utils.py
+++ b/backend/open_webui/apps/retrieval/utils.py
@@ -307,7 +307,7 @@ def get_embedding_function(
return lambda query: generate_multiple(query, func)
-def get_rag_context(
+def get_sources_from_files(
files,
queries,
embedding_function,
@@ -387,43 +387,24 @@ def get_rag_context(
del file["data"]
relevant_contexts.append({**context, "file": file})
- contexts = []
- citations = []
+ sources = []
for context in relevant_contexts:
try:
if "documents" in context:
- file_names = list(
- set(
- [
- metadata["name"]
- for metadata in context["metadatas"][0]
- if metadata is not None and "name" in metadata
- ]
- )
- )
- contexts.append(
- ((", ".join(file_names) + ":\n\n") if file_names else "")
- + "\n\n".join(
- [text for text in context["documents"][0] if text is not None]
- )
- )
-
if "metadatas" in context:
- citation = {
+ source = {
"source": context["file"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
if "distances" in context and context["distances"]:
- citation["distances"] = context["distances"][0]
- citations.append(citation)
+ source["distances"] = context["distances"][0]
+
+ sources.append(source)
except Exception as e:
log.exception(e)
- print("contexts", contexts)
- print("citations", citations)
-
- return contexts, citations
+ return sources
def get_model_path(model: str, update_model: bool = False):
diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py
index b8695eb672..e7459a15f2 100644
--- a/backend/open_webui/apps/webui/routers/files.py
+++ b/backend/open_webui/apps/webui/routers/files.py
@@ -56,7 +56,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
FileForm(
**{
"id": id,
- "filename": filename,
+ "filename": name,
"path": file_path,
"meta": {
"name": name,
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index d62ef158ec..797c9622a4 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -49,7 +49,7 @@ from open_webui.apps.openai.main import (
get_all_models_responses as get_openai_models_responses,
)
from open_webui.apps.retrieval.main import app as retrieval_app
-from open_webui.apps.retrieval.utils import get_rag_context, rag_template
+from open_webui.apps.retrieval.utils import get_sources_from_files, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
@@ -380,8 +380,7 @@ async def chat_completion_tools_handler(
return body, {}
skip_files = False
- contexts = []
- citations = []
+ sources = []
task_model_id = get_task_model_id(
body["model"],
@@ -465,24 +464,37 @@ async def chat_completion_tools_handler(
print(tools[tool_function_name]["citation"])
- if tools[tool_function_name]["citation"]:
- citations.append(
- {
- "source": {
- "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
- },
- "document": [tool_output],
- "metadata": [{"source": tool_function_name}],
- }
- )
- else:
- citations.append({})
-
- if tools[tool_function_name]["file_handler"]:
- skip_files = True
-
if isinstance(tool_output, str):
- contexts.append(tool_output)
+ if tools[tool_function_name]["citation"]:
+ sources.append(
+ {
+ "source": {
+ "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+ },
+ "document": [tool_output],
+ "metadata": [
+ {
+ "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+ }
+ ],
+ }
+ )
+ else:
+ sources.append(
+ {
+ "source": {},
+ "document": [tool_output],
+ "metadata": [
+ {
+ "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+ }
+ ],
+ }
+ )
+
+ if tools[tool_function_name]["file_handler"]:
+ skip_files = True
+
except Exception as e:
log.exception(f"Error: {e}")
content = None
@@ -490,19 +502,18 @@ async def chat_completion_tools_handler(
log.exception(f"Error: {e}")
content = None
- log.debug(f"tool_contexts: {contexts} {citations}")
+ log.debug(f"tool_contexts: {sources}")
if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
- return body, {"contexts": contexts, "citations": citations}
+ return body, {"sources": sources}
async def chat_completion_files_handler(
body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
- contexts = []
- citations = []
+ sources = []
try:
queries_response = await generate_queries(
@@ -530,7 +541,7 @@ async def chat_completion_files_handler(
print(f"{queries=}")
if files := body.get("metadata", {}).get("files", None):
- contexts, citations = get_rag_context(
+ sources = get_sources_from_files(
files=files,
queries=queries,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
@@ -540,9 +551,8 @@ async def chat_completion_files_handler(
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
- log.debug(f"rag_contexts: {contexts}, citations: {citations}")
-
- return body, {"contexts": contexts, "citations": citations}
+ log.debug(f"rag_contexts:sources: {sources}")
+ return body, {"sources": sources}
def is_chat_completion_request(request):
@@ -643,8 +653,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Initialize data_items to store additional data to be sent to the client
# Initialize contexts and citation
data_items = []
- contexts = []
- citations = []
+ sources = []
try:
body, flags = await chat_completion_filter_functions_handler(
@@ -670,32 +679,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
body, flags = await chat_completion_tools_handler(
body, user, models, extra_params
)
- contexts.extend(flags.get("contexts", []))
- citations.extend(flags.get("citations", []))
+ sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
try:
body, flags = await chat_completion_files_handler(body, user)
- contexts.extend(flags.get("contexts", []))
- citations.extend(flags.get("citations", []))
+ sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
# If context is not empty, insert it into the messages
- if len(contexts) > 0:
+ if len(sources) > 0:
context_string = ""
- for context_idx, context in enumerate(contexts):
- print(context)
- source_id = citations[context_idx].get("source", {}).get("name", "")
+ for source_idx, source in enumerate(sources):
+ source_id = source.get("source", {}).get("name", "")
- print(f"\n\n\n\n{source_id}\n\n\n\n")
- if source_id:
- context_string += f"