This commit is contained in:
Timothy Jaeryang Baek
2026-05-09 07:52:15 +09:00
parent 85c7373f68
commit 485d689cfd
2 changed files with 87 additions and 0 deletions

View File

@@ -241,6 +241,81 @@ class ChatMessageTable:
messages = result.scalars().all()
return [ChatMessageModel.model_validate(message) for message in messages]
# Columns that exist purely for database bookkeeping; they are never copied
# into the reconstructed message dict.
EXCLUDED_COLUMNS = frozenset(
    (
        'id',
        'chat_id',
        'user_id',
        'updated_at',
    )
)
# Translation table for the DB columns whose name is not the key used in the
# JSON message (DB column name -> JSON message key).
DB_TO_JSON_KEY_MAP = {
    'parent_id': 'parentId',
    'model_id': 'model',
    'status_history': 'statusHistory',
    'created_at': 'timestamp',
}
async def get_messages_map_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[dict]:
    """Build a ``{message_id: message_dict}`` map from ``chat_message`` rows.

    The result has the same shape as ``chat.history.messages`` so callers
    (``get_message_list``, middleware) work unchanged.

    Each row becomes a dict keyed by the JSON-side names: columns listed in
    ``EXCLUDED_COLUMNS`` are dropped, ``None`` values are omitted, and the
    remaining column names are translated through ``DB_TO_JSON_KEY_MAP``.
    ``content`` defaults to ``''``, ``usage`` is mirrored into
    ``info.usage``, and ``childrenIds`` is rebuilt from the ``parentId``
    links (leaf messages get ``[]``).

    Rows are read ordered by ``created_at`` (ties broken by ``id``), so the
    map's iteration order and the order inside every ``childrenIds`` list
    are the same on every call.

    Args:
        chat_id: ID of the chat whose messages should be loaded.
        db: Optional session handed to ``get_async_db_context``.

    Returns:
        The message map, or ``None`` if no rows exist for the chat (the
        caller should fall back to the embedded JSON blob for legacy chats).
    """
    # Strip the composite-id prefix ("{chat_id}-") to recover the original
    # message_id used as map key.
    prefix = f'{chat_id}-'
    prefix_len = len(prefix)

    # Resolve the exported columns and their JSON key names once instead of
    # repeating the exclusion test and key lookup for every row.
    excluded = self.EXCLUDED_COLUMNS
    key_map = self.DB_TO_JSON_KEY_MAP
    columns = [
        (col.key, key_map.get(col.key, col.key))
        for col in ChatMessage.__table__.columns
        if col.key not in excluded
    ]

    async with get_async_db_context(db) as db:
        result = await db.execute(
            select(ChatMessage)
            .filter_by(chat_id=chat_id)
            # A SELECT without ORDER BY returns rows in an unspecified
            # order; childrenIds is built in row order below, so sibling
            # order would otherwise be able to change between calls.
            .order_by(ChatMessage.created_at, ChatMessage.id)
        )
        rows = result.scalars().all()
        if not rows:
            return None

        messages_map: dict[str, dict] = {}
        for row in rows:
            msg_id = row.id[prefix_len:] if row.id.startswith(prefix) else row.id
            msg: dict = {'id': msg_id}
            for attr, json_key in columns:
                val = getattr(row, attr)
                if val is not None:
                    msg[json_key] = val

            # Ensure content always has a value
            msg.setdefault('content', '')

            # Mirror usage into info.usage for callers that read it there
            if 'usage' in msg:
                msg['info'] = {'usage': msg['usage']}

            messages_map[msg_id] = msg

        # Reconstruct childrenIds from parentId links so that the map is
        # fully navigable (callers like the frontend rely on this). Links
        # to a parent that is not part of this chat's rows are left as-is
        # and simply contribute no child entry.
        for msg_id, msg in messages_map.items():
            parent_id = msg.get('parentId')
            if parent_id and parent_id in messages_map:
                children = messages_map[parent_id].setdefault('childrenIds', [])
                if msg_id not in children:
                    children.append(msg_id)

        # Ensure every message has a childrenIds list (leaf nodes get [])
        for msg in messages_map.values():
            msg.setdefault('childrenIds', [])

        return messages_map
async def get_messages_by_user_id(
self,
user_id: str,

View File

@@ -460,6 +460,18 @@ class ChatTable:
return row[0] or 'New Chat'
async def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]:
"""Message map for walking history (see ``get_message_list``).
Prefer ``chat_message`` rows to avoid loading the large ``chat``
JSON blob; fall back to embedded history when no rows exist
(legacy chats).
"""
# Fast path: build from normalized chat_message rows.
messages_map = await ChatMessages.get_messages_map_by_chat_id(id)
if messages_map is not None:
return messages_map
# No rows — fall back to the embedded JSON blob for legacy chats.
chat = await self.get_chat_by_id(id)
if chat is None:
return None