mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 19:37:47 +01:00
feat: show RAG query results as citations
This commit is contained in:
committed by
Timothy J. Baek
parent
ba09fcd548
commit
0872bea790
@@ -320,11 +320,19 @@ def rag_messages(
|
||||
extracted_collections.extend(collection)
|
||||
|
||||
context_string = ""
|
||||
citations = []
|
||||
for context in relevant_contexts:
|
||||
try:
|
||||
if "documents" in context:
|
||||
items = [item for item in context["documents"][0] if item is not None]
|
||||
context_string += "\n\n".join(items)
|
||||
if "metadatas" in context:
|
||||
citations.append(
|
||||
{
|
||||
"document": context["documents"][0],
|
||||
"metadata": context["metadatas"][0],
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
context_string = context_string.strip()
|
||||
@@ -355,7 +363,7 @@ def rag_messages(
|
||||
|
||||
messages[last_user_message_idx] = new_user_message
|
||||
|
||||
return messages
|
||||
return messages, citations
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
|
||||
|
||||
@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from apps.ollama.main import app as ollama_app
|
||||
from apps.openai.main import app as openai_app
|
||||
@@ -102,6 +102,8 @@ origins = ["*"]
|
||||
|
||||
class RAGMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
return_citations = False
|
||||
|
||||
if request.method == "POST" and (
|
||||
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||
):
|
||||
@@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
return_citations = data.get("citations", False)
|
||||
if "citations" in data:
|
||||
del data["citations"]
|
||||
|
||||
# Example: Add a new key-value pair or modify existing ones
|
||||
# data["modified"] = True # Example modification
|
||||
if "docs" in data:
|
||||
data = {**data}
|
||||
data["messages"] = rag_messages(
|
||||
data["messages"], citations = rag_messages(
|
||||
docs=data["docs"],
|
||||
messages=data["messages"],
|
||||
template=rag_app.state.RAG_TEMPLATE,
|
||||
@@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
del data["docs"]
|
||||
|
||||
log.debug(f"data['messages']: {data['messages']}")
|
||||
log.debug(
|
||||
f"data['messages']: {data['messages']}, citations: {citations}"
|
||||
)
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
|
||||
@@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
]
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
if return_citations:
|
||||
# Inject the citations into the response
|
||||
if isinstance(response, StreamingResponse):
|
||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if "text/event-stream" in content_type:
|
||||
return StreamingResponse(
|
||||
self.openai_stream_wrapper(response.body_iterator, citations),
|
||||
)
|
||||
if "application/x-ndjson" in content_type:
|
||||
return StreamingResponse(
|
||||
self.ollama_stream_wrapper(response.body_iterator, citations),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def _receive(self, body: bytes):
    """Build a single ASGI ``http.request`` message that carries ``body``.

    ``more_body`` is set to ``False``, so the returned message describes a
    request body that is delivered in one piece.
    """
    message = dict(type="http.request", body=body, more_body=False)
    return message
|
||||
|
||||
async def openai_stream_wrapper(self, original_generator, citations):
    """Prefix a stream with one SSE ``data:`` event holding the citations.

    The first chunk yielded is ``data: <json>`` followed by a blank line,
    where ``<json>`` is ``{"citations": citations}`` serialized with
    ``json.dumps``. Every chunk from ``original_generator`` is then relayed
    unchanged and in order.
    """
    payload = json.dumps({"citations": citations})
    yield f"data: {payload}\n\n"

    # Pass the upstream chunks through untouched.
    async for chunk in original_generator:
        yield chunk
|
||||
|
||||
async def ollama_stream_wrapper(self, original_generator, citations):
    """Prefix a stream with one newline-terminated JSON line of citations.

    The first chunk yielded is ``{"citations": citations}`` serialized with
    ``json.dumps`` and terminated by a single newline. Every chunk from
    ``original_generator`` is then relayed unchanged and in order.
    """
    first_line = json.dumps({"citations": citations}) + "\n"
    yield first_line

    # Pass the upstream chunks through untouched.
    async for chunk in original_generator:
        yield chunk
|
||||
|
||||
|
||||
app.add_middleware(RAGMiddleware)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user