enh: RAG full context mode

This commit is contained in:
Timothy Jaeryang Baek
2025-02-18 21:14:58 -08:00
parent a6a7c548d5
commit 81715f6553
6 changed files with 127 additions and 32 deletions

View File

@@ -84,6 +84,19 @@ def query_doc(
raise e
def get_doc(collection_name: str, user: UserModel = None):
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
print(e)
raise e
def query_doc_with_hybrid_search(
collection_name: str,
query: str,
@@ -137,6 +150,24 @@ def query_doc_with_hybrid_search(
raise e
def merge_get_results(get_results: list[dict]) -> dict:
# Initialize lists to store combined data
combined_documents = []
combined_metadatas = []
for data in get_results:
combined_documents.extend(data["documents"][0])
combined_metadatas.extend(data["metadatas"][0])
# Create the output dictionary
result = {
"documents": [combined_documents],
"metadatas": [combined_metadatas],
}
return result
def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False
) -> list[dict]:
@@ -194,6 +225,23 @@ def merge_and_sort_query_results(
return result
def get_all_items_from_collections(collection_names: list[str]) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
try:
result = get_doc(collection_name=collection_name)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_get_results(results)
def query_collection(
collection_names: list[str],
queries: list[str],
@@ -311,8 +359,11 @@ def get_sources_from_files(
reranking_function,
r,
hybrid_search,
full_context=False,
):
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
relevant_contexts = []
@@ -350,36 +401,45 @@ def get_sources_from_files(
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
if full_context:
try:
context = get_all_items_from_collections(collection_names)
print("context", context)
except Exception as e:
log.exception(e)
else:
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names)