feat: retrieval whole document mode

This commit is contained in:
Timothy J. Baek
2024-09-29 22:52:27 +02:00
parent 1d8b3b8c51
commit 6d764ee55e
5 changed files with 112 additions and 67 deletions

View File

@@ -317,58 +317,63 @@ def get_rag_context(
relevant_contexts = []
for file in files:
context = None
collection_names = (
file["collection_names"]
if file["type"] == "collection"
else [file["collection_name"]] if file["collection_name"] else []
)
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
if file.get("context") == "full":
context = {
"documents": [[file["content"]]],
"metadatas": [[{"file_id": file["id"], "name": file["name"]}]],
}
else:
context = None
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names = (
file["collection_names"]
if file["type"] == "collection"
else [file["collection_name"]] if file["collection_name"] else []
)
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
context = None
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
except Exception as e:
log.exception(e)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names)
if context:
relevant_contexts.append({**context, "source": file})
extracted_collections.extend(collection_names)
relevant_contexts.append({**context, "file": file})
contexts = []
citations = []
for context in relevant_contexts:
try:
if "documents" in context:
@@ -381,7 +386,7 @@ def get_rag_context(
if "metadatas" in context:
citations.append(
{
"source": context["source"],
"source": context["file"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}