enh: bypass embedding and retrieval

This commit is contained in:
Timothy Jaeryang Baek
2025-02-26 15:42:19 -08:00
parent 1c2e36f1b7
commit 57010901e6
10 changed files with 486 additions and 370 deletions

View File

@@ -351,24 +351,25 @@ async def chat_web_search_handler(
all_results.append(results)
files = form_data.get("files", [])
if request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT:
files.append(
{
"docs": results.get("docs", []),
"name": searchQuery,
"type": "web_search_docs",
"urls": results["filenames"],
}
)
else:
if results.get("collection_name"):
files.append(
{
"collection_name": results["collection_name"],
"name": searchQuery,
"type": "web_search_results",
"type": "web_search",
"urls": results["filenames"],
}
)
elif results.get("docs"):
files.append(
{
"docs": results.get("docs", []),
"name": searchQuery,
"type": "web_search",
"urls": results["filenames"],
}
)
form_data["files"] = files
except Exception as e:
log.exception(e)
@@ -518,6 +519,7 @@ async def chat_completion_files_handler(
sources = []
if files := body.get("metadata", {}).get("files", None):
queries = []
try:
queries_response = await generate_queries(
request,
@@ -543,8 +545,8 @@ async def chat_completion_files_handler(
queries_response = {"queries": [queries_response]}
queries = queries_response.get("queries", [])
except Exception as e:
queries = []
except:
pass
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
@@ -556,6 +558,7 @@ async def chat_completion_files_handler(
sources = await loop.run_in_executor(
executor,
lambda: get_sources_from_files(
request=request,
files=files,
queries=queries,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
@@ -738,6 +741,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# Remove duplicate files
if files:
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
@@ -795,8 +799,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
if len(sources) > 0:
context_string = ""
for source_idx, source in enumerate(sources):
source_id = source.get("source", {}).get("name", "")
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
@@ -1913,7 +1915,9 @@ async def process_chat_response(
)
log.info(f"content_blocks={content_blocks}")
log.info(f"serialize_content_blocks={serialize_content_blocks(content_blocks)}")
log.info(
f"serialize_content_blocks={serialize_content_blocks(content_blocks)}"
)
try:
res = await generate_chat_completion(