refac: rag

This commit is contained in:
Timothy J. Baek
2024-06-09 03:01:25 -07:00
parent 277fc3feac
commit f2b9a5f5bf
3 changed files with 46 additions and 46 deletions

View File

@@ -20,7 +20,7 @@ from langchain.retrievers import (
from typing import Optional
from utils.misc import get_last_user_message, add_or_update_system_message
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
@@ -247,31 +247,7 @@ def rag_messages(
hybrid_search,
):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_message_idx = i
break
user_message = messages[last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
query = get_last_user_message(messages)
extracted_collections = []
relevant_contexts = []
@@ -349,24 +325,7 @@ def rag_messages(
)
log.debug(f"ra_content: {ra_content}")
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
messages[last_user_message_idx] = new_user_message
messages = add_or_update_system_message(ra_content, messages)
return messages, citations