mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
enh: RAG full context mode
This commit is contained in:
@@ -84,6 +84,19 @@ def query_doc(
|
||||
raise e
|
||||
|
||||
|
||||
def get_doc(collection_name: str, user: UserModel = None):
|
||||
try:
|
||||
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
||||
|
||||
if result:
|
||||
log.info(f"query_doc:result {result.ids} {result.metadatas}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
|
||||
def query_doc_with_hybrid_search(
|
||||
collection_name: str,
|
||||
query: str,
|
||||
@@ -137,6 +150,24 @@ def query_doc_with_hybrid_search(
|
||||
raise e
|
||||
|
||||
|
||||
def merge_get_results(get_results: list[dict]) -> dict:
|
||||
# Initialize lists to store combined data
|
||||
combined_documents = []
|
||||
combined_metadatas = []
|
||||
|
||||
for data in get_results:
|
||||
combined_documents.extend(data["documents"][0])
|
||||
combined_metadatas.extend(data["metadatas"][0])
|
||||
|
||||
# Create the output dictionary
|
||||
result = {
|
||||
"documents": [combined_documents],
|
||||
"metadatas": [combined_metadatas],
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
) -> list[dict]:
|
||||
@@ -194,6 +225,23 @@ def merge_and_sort_query_results(
|
||||
return result
|
||||
|
||||
|
||||
def get_all_items_from_collections(collection_names: list[str]) -> dict:
|
||||
results = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
try:
|
||||
result = get_doc(collection_name=collection_name)
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
|
||||
return merge_get_results(results)
|
||||
|
||||
|
||||
def query_collection(
|
||||
collection_names: list[str],
|
||||
queries: list[str],
|
||||
@@ -311,8 +359,11 @@ def get_sources_from_files(
|
||||
reranking_function,
|
||||
r,
|
||||
hybrid_search,
|
||||
full_context=False,
|
||||
):
|
||||
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
|
||||
log.debug(
|
||||
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
|
||||
)
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
@@ -350,36 +401,45 @@ def get_sources_from_files(
|
||||
log.debug(f"skipping {file} as it has already been extracted")
|
||||
continue
|
||||
|
||||
try:
|
||||
context = None
|
||||
if file.get("type") == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
if full_context:
|
||||
try:
|
||||
context = get_all_items_from_collections(collection_names)
|
||||
|
||||
print("context", context)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
else:
|
||||
try:
|
||||
context = None
|
||||
if file.get("type") == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user