This commit is contained in:
Timothy J. Baek
2024-10-11 00:00:13 -07:00
parent 9658c2559a
commit ba2df1c33a
3 changed files with 55 additions and 23 deletions

View File

@@ -393,28 +393,56 @@ class ChatTable:
if not include_archived:
query = query.filter(Chat.archived == False)
# Fetch all potentially relevant chats
all_chats = query.all()
query = query.order_by(Chat.updated_at.desc())
# Filter chats using Python
filtered_chats = []
for chat in all_chats:
# Check chat title
title_matches = search_text in chat.title.lower()
# Check if the database dialect is either 'sqlite' or 'postgresql'
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Check chat content in chat JSON
content_matches = any(
search_text in message.get("content", "").lower()
for message in chat.chat.get("messages", [])
if "content" in message
)
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
if title_matches or content_matches:
filtered_chats.append(chat)
# Implementing pagination manually
paginated_chats = filtered_chats[skip : skip + limit]
return [ChatModel.model_validate(chat) for chat in paginated_chats]
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db: