refac: citations -> sources

This commit is contained in:
Timothy Jaeryang Baek
2024-11-21 19:46:09 -08:00
parent 7062e637e8
commit 81386e9b04
9 changed files with 165 additions and 126 deletions

View File

@@ -49,7 +49,7 @@ from open_webui.apps.openai.main import (
get_all_models_responses as get_openai_models_responses,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.retrieval.utils import get_sources_from_files, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
@@ -380,8 +380,7 @@ async def chat_completion_tools_handler(
return body, {}
skip_files = False
contexts = []
citations = []
sources = []
task_model_id = get_task_model_id(
body["model"],
@@ -465,24 +464,37 @@ async def chat_completion_tools_handler(
print(tools[tool_function_name]["citation"])
if tools[tool_function_name]["citation"]:
citations.append(
{
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [{"source": tool_function_name}],
}
)
else:
citations.append({})
if tools[tool_function_name]["file_handler"]:
skip_files = True
if isinstance(tool_output, str):
contexts.append(tool_output)
if tools[tool_function_name]["citation"]:
sources.append(
{
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [
{
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
}
],
}
)
else:
sources.append(
{
"source": {},
"document": [tool_output],
"metadata": [
{
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
}
],
}
)
if tools[tool_function_name]["file_handler"]:
skip_files = True
except Exception as e:
log.exception(f"Error: {e}")
content = None
@@ -490,19 +502,18 @@ async def chat_completion_tools_handler(
log.exception(f"Error: {e}")
content = None
log.debug(f"tool_contexts: {contexts} {citations}")
log.debug(f"tool_contexts: {sources}")
if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
return body, {"contexts": contexts, "citations": citations}
return body, {"sources": sources}
async def chat_completion_files_handler(
body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
contexts = []
citations = []
sources = []
try:
queries_response = await generate_queries(
@@ -530,7 +541,7 @@ async def chat_completion_files_handler(
print(f"{queries=}")
if files := body.get("metadata", {}).get("files", None):
contexts, citations = get_rag_context(
sources = get_sources_from_files(
files=files,
queries=queries,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
@@ -540,9 +551,8 @@ async def chat_completion_files_handler(
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
return body, {"contexts": contexts, "citations": citations}
log.debug(f"rag_contexts:sources: {sources}")
return body, {"sources": sources}
def is_chat_completion_request(request):
@@ -643,8 +653,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Initialize data_items to store additional data to be sent to the client
# Initialize contexts and citation
data_items = []
contexts = []
citations = []
sources = []
try:
body, flags = await chat_completion_filter_functions_handler(
@@ -670,32 +679,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
body, flags = await chat_completion_tools_handler(
body, user, models, extra_params
)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
try:
body, flags = await chat_completion_files_handler(body, user)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
# If context is not empty, insert it into the messages
if len(contexts) > 0:
if len(sources) > 0:
context_string = ""
for context_idx, context in enumerate(contexts):
print(context)
source_id = citations[context_idx].get("source", {}).get("name", "")
for source_idx, source in enumerate(sources):
source_id = source.get("source", {}).get("name", "")
print(f"\n\n\n\n{source_id}\n\n\n\n")
if source_id:
context_string += f"<source><source_id>{source_id}</source_id><source_context>{context}</source_context></source>\n"
else:
context_string += (
f"<source><source_context>{context}</source_context></source>\n"
)
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
metadata = source.get("metadata")
if metadata:
doc_source_id = metadata[doc_idx].get("source", source_id)
if source_id:
context_string += f"<source><source_id>{doc_source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
else:
# If there is no source_id, then do not include the source_id tag
context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
context_string = context_string.strip()
prompt = get_last_user_message(body["messages"])
@@ -728,8 +739,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
# If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations})
sources = [
source for source in sources if source.get("source", {}).get("name", "")
]
if len(sources) > 0:
data_items.append({"sources": sources})
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one