mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 03:47:49 +01:00
refac: "rag" endpoints renamed to "retrieval"
This commit is contained in:
524
backend/open_webui/apps/retrieval/utils.py
Normal file
524
backend/open_webui/apps/retrieval/utils.py
Normal file
@@ -0,0 +1,524 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateEmbeddingsForm,
|
||||
generate_ollama_embeddings,
|
||||
)
|
||||
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class VectorSearchRetriever(BaseRetriever):
|
||||
collection_name: Any
|
||||
embedding_function: Any
|
||||
top_k: int
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
) -> list[Document]:
|
||||
result = VECTOR_DB_CLIENT.search(
|
||||
collection_name=self.collection_name,
|
||||
vectors=[self.embedding_function(query)],
|
||||
limit=self.top_k,
|
||||
)
|
||||
|
||||
ids = result.ids[0]
|
||||
metadatas = result.metadatas[0]
|
||||
documents = result.documents[0]
|
||||
|
||||
results = []
|
||||
for idx in range(len(ids)):
|
||||
results.append(
|
||||
Document(
|
||||
metadata=metadatas[idx],
|
||||
page_content=documents[idx],
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def query_doc(
|
||||
collection_name: str,
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
):
|
||||
try:
|
||||
result = VECTOR_DB_CLIENT.search(
|
||||
collection_name=collection_name,
|
||||
vectors=[embedding_function(query)],
|
||||
limit=k,
|
||||
)
|
||||
|
||||
log.info(f"query_doc:result {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
|
||||
def query_doc_with_hybrid_search(
|
||||
collection_name: str,
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
reranking_function,
|
||||
r: float,
|
||||
) -> dict:
|
||||
try:
|
||||
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
||||
|
||||
bm25_retriever = BM25Retriever.from_texts(
|
||||
texts=result.documents[0],
|
||||
metadatas=result.metadatas[0],
|
||||
)
|
||||
bm25_retriever.k = k
|
||||
|
||||
vector_search_retriever = VectorSearchRetriever(
|
||||
collection_name=collection_name,
|
||||
embedding_function=embedding_function,
|
||||
top_k=k,
|
||||
)
|
||||
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
|
||||
)
|
||||
compressor = RerankCompressor(
|
||||
embedding_function=embedding_function,
|
||||
top_n=k,
|
||||
reranking_function=reranking_function,
|
||||
r_score=r,
|
||||
)
|
||||
|
||||
compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=compressor, base_retriever=ensemble_retriever
|
||||
)
|
||||
|
||||
result = compression_retriever.invoke(query)
|
||||
result = {
|
||||
"distances": [[d.metadata.get("score") for d in result]],
|
||||
"documents": [[d.page_content for d in result]],
|
||||
"metadatas": [[d.metadata for d in result]],
|
||||
}
|
||||
|
||||
log.info(f"query_doc_with_hybrid_search:result {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
) -> list[dict]:
|
||||
# Initialize lists to store combined data
|
||||
combined_distances = []
|
||||
combined_documents = []
|
||||
combined_metadatas = []
|
||||
|
||||
for data in query_results:
|
||||
combined_distances.extend(data["distances"][0])
|
||||
combined_documents.extend(data["documents"][0])
|
||||
combined_metadatas.extend(data["metadatas"][0])
|
||||
|
||||
# Create a list of tuples (distance, document, metadata)
|
||||
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
||||
|
||||
# Sort the list based on distances
|
||||
combined.sort(key=lambda x: x[0], reverse=reverse)
|
||||
|
||||
# We don't have anything :-(
|
||||
if not combined:
|
||||
sorted_distances = []
|
||||
sorted_documents = []
|
||||
sorted_metadatas = []
|
||||
else:
|
||||
# Unzip the sorted list
|
||||
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
|
||||
|
||||
# Slicing the lists to include only k elements
|
||||
sorted_distances = list(sorted_distances)[:k]
|
||||
sorted_documents = list(sorted_documents)[:k]
|
||||
sorted_metadatas = list(sorted_metadatas)[:k]
|
||||
|
||||
# Create the output dictionary
|
||||
result = {
|
||||
"distances": [sorted_distances],
|
||||
"documents": [sorted_documents],
|
||||
"metadatas": [sorted_metadatas],
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def query_collection(
|
||||
collection_names: list[str],
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
) -> dict:
|
||||
results = []
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
try:
|
||||
result = query_doc(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
|
||||
return merge_and_sort_query_results(results, k=k)
|
||||
|
||||
|
||||
def query_collection_with_hybrid_search(
|
||||
collection_names: list[str],
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
reranking_function,
|
||||
r: float,
|
||||
) -> dict:
|
||||
results = []
|
||||
error = False
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
result = query_doc_with_hybrid_search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
"Error when querying the collection with " f"hybrid_search: {e}"
|
||||
)
|
||||
error = True
|
||||
|
||||
if error:
|
||||
raise Exception(
|
||||
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
|
||||
)
|
||||
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
|
||||
|
||||
def rag_template(template: str, context: str, query: str):
|
||||
count = template.count("[context]")
|
||||
assert "[context]" in template, "RAG template does not contain '[context]'"
|
||||
|
||||
if "<context>" in context and "</context>" in context:
|
||||
log.debug(
|
||||
"WARNING: Potential prompt injection attack: the RAG "
|
||||
"context contains '<context>' and '</context>'. This might be "
|
||||
"nothing, or the user might be trying to hack something."
|
||||
)
|
||||
|
||||
if "[query]" in context:
|
||||
query_placeholder = f"[query-{str(uuid.uuid4())}]"
|
||||
template = template.replace("[query]", query_placeholder)
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace(query_placeholder, query)
|
||||
else:
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace("[query]", query)
|
||||
return template
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
embedding_function,
|
||||
openai_key,
|
||||
openai_url,
|
||||
batch_size,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
if embedding_engine == "ollama":
|
||||
func = lambda query: generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": embedding_model,
|
||||
"prompt": query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif embedding_engine == "openai":
|
||||
func = lambda query: generate_openai_embeddings(
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key,
|
||||
url=openai_url,
|
||||
)
|
||||
|
||||
def generate_multiple(query, f):
|
||||
if isinstance(query, list):
|
||||
if embedding_engine == "openai":
|
||||
embeddings = []
|
||||
for i in range(0, len(query), batch_size):
|
||||
embeddings.extend(f(query[i : i + batch_size]))
|
||||
return embeddings
|
||||
else:
|
||||
return [f(q) for q in query]
|
||||
else:
|
||||
return f(query)
|
||||
|
||||
return lambda query: generate_multiple(query, func)
|
||||
|
||||
|
||||
def get_rag_context(
|
||||
files,
|
||||
messages,
|
||||
embedding_function,
|
||||
k,
|
||||
reranking_function,
|
||||
r,
|
||||
hybrid_search,
|
||||
):
|
||||
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
|
||||
query = get_last_user_message(messages)
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
|
||||
for file in files:
|
||||
context = None
|
||||
|
||||
collection_names = (
|
||||
file["collection_names"]
|
||||
if file["type"] == "collection"
|
||||
else [file["collection_name"]] if file["collection_name"] else []
|
||||
)
|
||||
|
||||
collection_names = set(collection_names).difference(extracted_collections)
|
||||
if not collection_names:
|
||||
log.debug(f"skipping {file} as it has already been extracted")
|
||||
continue
|
||||
|
||||
try:
|
||||
context = None
|
||||
if file["type"] == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
if context:
|
||||
relevant_contexts.append({**context, "source": file})
|
||||
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
contexts = []
|
||||
citations = []
|
||||
|
||||
for context in relevant_contexts:
|
||||
try:
|
||||
if "documents" in context:
|
||||
contexts.append(
|
||||
"\n\n".join(
|
||||
[text for text in context["documents"][0] if text is not None]
|
||||
)
|
||||
)
|
||||
|
||||
if "metadatas" in context:
|
||||
citations.append(
|
||||
{
|
||||
"source": context["source"],
|
||||
"document": context["documents"][0],
|
||||
"metadata": context["metadatas"][0],
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
return contexts, citations
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
|
||||
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
|
||||
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
|
||||
|
||||
local_files_only = not update_model
|
||||
|
||||
snapshot_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
log.debug(f"model: {model}")
|
||||
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
|
||||
|
||||
# Inspiration from upstream sentence_transformers
|
||||
if (
|
||||
os.path.exists(model)
|
||||
or ("\\" in model or model.count("/") > 1)
|
||||
and local_files_only
|
||||
):
|
||||
# If fully qualified path exists, return input, else set repo_id
|
||||
return model
|
||||
elif "/" not in model:
|
||||
# Set valid repo_id for model short-name
|
||||
model = "sentence-transformers" + "/" + model
|
||||
|
||||
snapshot_kwargs["repo_id"] = model
|
||||
|
||||
# Attempt to query the huggingface_hub library to determine the local path and/or to update
|
||||
try:
|
||||
model_repo_path = snapshot_download(**snapshot_kwargs)
|
||||
log.debug(f"model_repo_path: {model_repo_path}")
|
||||
return model_repo_path
|
||||
except Exception as e:
|
||||
log.exception(f"Cannot determine model snapshot path: {e}")
|
||||
return model
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str,
|
||||
text: Union[str, list[str]],
|
||||
key: str,
|
||||
url: str = "https://api.openai.com/v1",
|
||||
):
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{url}/embeddings",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
},
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "data" in data:
|
||||
return [elem["embedding"] for elem in data["data"]]
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
import operator
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
|
||||
|
||||
class RerankCompressor(BaseDocumentCompressor):
|
||||
embedding_function: Any
|
||||
top_n: int
|
||||
reranking_function: Any
|
||||
r_score: float
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
reranking = self.reranking_function is not None
|
||||
|
||||
if reranking:
|
||||
scores = self.reranking_function.predict(
|
||||
[(query, doc.page_content) for doc in documents]
|
||||
)
|
||||
else:
|
||||
from sentence_transformers import util
|
||||
|
||||
query_embedding = self.embedding_function(query)
|
||||
document_embedding = self.embedding_function(
|
||||
[doc.page_content for doc in documents]
|
||||
)
|
||||
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
||||
|
||||
docs_with_scores = list(zip(documents, scores.tolist()))
|
||||
if self.r_score:
|
||||
docs_with_scores = [
|
||||
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
||||
]
|
||||
|
||||
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
||||
final_results = []
|
||||
for doc, doc_score in result[: self.top_n]:
|
||||
metadata = doc.metadata
|
||||
metadata["score"] = doc_score
|
||||
doc = Document(
|
||||
page_content=doc.page_content,
|
||||
metadata=metadata,
|
||||
)
|
||||
final_results.append(doc)
|
||||
return final_results
|
||||
Reference in New Issue
Block a user