This commit is contained in:
Timothy Jaeryang Baek
2026-02-16 00:41:36 -06:00
parent 33308022f0
commit c748c3ede7
3 changed files with 106 additions and 56 deletions

View File

@@ -431,22 +431,31 @@ class ChatTable:
def update_chat_tags_by_id(
self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
with get_db_context() as db:
chat = db.get(Chat, id)
if chat is None:
return None
self.delete_all_tags_by_id_and_user_id(id, user.id)
old_tags = chat.meta.get("tags", [])
new_tags = [t for t in tags if t.replace(" ", "_").lower() != "none"]
new_tag_ids = [t.replace(" ", "_").lower() for t in new_tags]
for tag in chat.meta.get("tags", []):
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
# Single meta update
chat.meta = {**chat.meta, "tags": new_tag_ids}
db.commit()
db.refresh(chat)
for tag_name in tags:
if tag_name.lower() == "none":
continue
# Batch-create any missing tag rows
Tags.ensure_tags_exist(new_tags, user.id, db=db)
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
return self.get_chat_by_id(id)
# Clean up orphaned old tags in one query
removed = set(old_tags) - set(new_tag_ids)
if removed:
self.delete_orphan_tags_for_user(
list(removed), user.id, db=db
)
return ChatModel.model_validate(chat)
def get_chat_title_by_id(self, id: str) -> Optional[str]:
chat = self.get_chat_by_id(id)
@@ -1267,8 +1276,8 @@ class ChatTable:
) -> list[TagModel]:
with get_db_context(db) as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
tag_ids = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=db)
def get_chat_list_by_user_id_and_tag_name(
self,
@@ -1309,20 +1318,16 @@ class ChatTable:
def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None
) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id)
tag_id = tag_name.replace(" ", "_").lower()
Tags.ensure_tags_exist([tag_name], user_id, db=db)
try:
with get_db_context(db) as db:
chat = db.get(Chat, id)
tag_id = tag.id
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
}
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
@@ -1332,40 +1337,55 @@ class ChatTable:
def count_chats_by_tag_name_and_user_id(
self, tag_name: str, user_id: str, db: Optional[Session] = None
) -> int:
with get_db_context(db) as db: # Assuming `get_db()` returns a session object
with get_db_context(db) as db:
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Get the count of matching records
count = query.count()
return query.count()
# Debugging output for inspection
log.info(f"Count of chats for tag '{tag_name}': {count}")
def delete_orphan_tags_for_user(
self,
tag_ids: list[str],
user_id: str,
threshold: int = 0,
db: Optional[Session] = None,
) -> None:
"""Delete tag rows from *tag_ids* that appear in at most *threshold*
non-archived chats for *user_id*. One query to find orphans, one to
delete them.
return count
Use threshold=0 after a tag is already removed from a chat's meta.
Use threshold=1 when the chat itself is about to be deleted (the
referencing chat still exists at query time).
"""
if not tag_ids:
return
with get_db_context(db) as db:
orphans = []
for tag_id in tag_ids:
count = self.count_chats_by_tag_name_and_user_id(
tag_id, user_id, db=db
)
if count <= threshold:
orphans.append(tag_id)
Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db)
def count_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str, db: Optional[Session] = None

View File

@@ -115,5 +115,45 @@ class TagTable:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str, db: Optional[Session] = None
) -> bool:
"""Delete all tags whose id is in *ids* for the given user, in one query."""
if not ids:
return True
try:
with get_db_context(db) as db:
db.query(Tag).filter(
Tag.id.in_(ids), Tag.user_id == user_id
).delete(synchronize_session=False)
db.commit()
return True
except Exception as e:
log.error(f"delete_tags_by_ids: {e}")
return False
def ensure_tags_exist(
self, names: list[str], user_id: str, db: Optional[Session] = None
) -> None:
"""Create tag rows for any *names* that don't already exist for *user_id*."""
if not names:
return
ids = [n.replace(" ", "_").lower() for n in names]
with get_db_context(db) as db:
existing = {
t.id
for t in db.query(Tag.id)
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
.all()
}
new_tags = [
Tag(id=tag_id, name=name, user_id=user_id)
for tag_id, name in zip(ids, names)
if tag_id not in existing
]
if new_tags:
db.add_all(new_tags)
db.commit()
Tags = TagTable()

View File

@@ -1131,9 +1131,9 @@ async def delete_chat_by_id(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
Chats.delete_orphan_tags_for_user(
chat.meta.get("tags", []), user.id, threshold=1, db=db
)
result = Chats.delete_chat_by_id(id, db=db)
@@ -1153,9 +1153,9 @@ async def delete_chat_by_id(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
Chats.delete_orphan_tags_for_user(
chat.meta.get("tags", []), user.id, threshold=1, db=db
)
result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db)
return result
@@ -1317,21 +1317,13 @@ async def archive_chat_by_id(
if chat:
chat = Chats.toggle_chat_archive_by_id(id, db=db)
# Delete tags if chat is archived
tag_ids = chat.meta.get("tags", [])
if chat.archived:
for tag_id in chat.meta.get("tags", []):
if (
Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db)
== 0
):
log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
# Archived chats are excluded from count — clean up orphans
Chats.delete_orphan_tags_for_user(tag_ids, user.id, db=db)
else:
for tag_id in chat.meta.get("tags", []):
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db)
if tag is None:
log.debug(f"inserting tag: {tag_id}")
tag = Tags.insert_new_tag(tag_id, user.id, db=db)
# Unarchived — ensure tag rows exist
Tags.ensure_tags_exist(tag_ids, user.id, db=db)
return ChatResponse(**chat.model_dump())
else:
@@ -1537,11 +1529,9 @@ async def delete_all_tags_by_id(
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
old_tags = chat.meta.get("tags", [])
Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
Chats.delete_orphan_tags_for_user(old_tags, user.id, db=db)
return True
else: