diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 7ae9f7a38b..2553e4306d 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -431,22 +431,31 @@ class ChatTable: def update_chat_tags_by_id( self, id: str, tags: list[str], user ) -> Optional[ChatModel]: - chat = self.get_chat_by_id(id) - if chat is None: - return None + with get_db_context() as db: + chat = db.get(Chat, id) + if chat is None: + return None - self.delete_all_tags_by_id_and_user_id(id, user.id) + old_tags = chat.meta.get("tags", []) + new_tags = [t for t in tags if t.replace(" ", "_").lower() != "none"] + new_tag_ids = [t.replace(" ", "_").lower() for t in new_tags] - for tag in chat.meta.get("tags", []): - if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + # Single meta update + chat.meta = {**chat.meta, "tags": new_tag_ids} + db.commit() + db.refresh(chat) - for tag_name in tags: - if tag_name.lower() == "none": - continue + # Batch-create any missing tag rows + Tags.ensure_tags_exist(new_tags, user.id, db=db) - self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name) - return self.get_chat_by_id(id) + # Clean up orphaned old tags in one query + removed = set(old_tags) - set(new_tag_ids) + if removed: + self.delete_orphan_tags_for_user( + list(removed), user.id, db=db + ) + + return ChatModel.model_validate(chat) def get_chat_title_by_id(self, id: str) -> Optional[str]: chat = self.get_chat_by_id(id) @@ -1267,8 +1276,8 @@ class ChatTable: ) -> list[TagModel]: with get_db_context(db) as db: chat = db.get(Chat, id) - tags = chat.meta.get("tags", []) - return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] + tag_ids = chat.meta.get("tags", []) + return Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=db) def get_chat_list_by_user_id_and_tag_name( self, @@ -1309,20 +1318,16 @@ class ChatTable: def add_chat_tag_by_id_and_user_id_and_tag_name( self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None ) -> Optional[ChatModel]: - tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) - if tag is None: - tag = Tags.insert_new_tag(tag_name, user_id) + tag_id = tag_name.replace(" ", "_").lower() + Tags.ensure_tags_exist([tag_name], user_id, db=db) try: with get_db_context(db) as db: chat = db.get(Chat, id) - - tag_id = tag.id if tag_id not in chat.meta.get("tags", []): chat.meta = { **chat.meta, "tags": list(set(chat.meta.get("tags", []) + [tag_id])), } - db.commit() db.refresh(chat) return ChatModel.model_validate(chat) @@ -1332,40 +1337,55 @@ class ChatTable: def count_chats_by_tag_name_and_user_id( self, tag_name: str, user_id: str, db: Optional[Session] = None ) -> int: - with get_db_context(db) as db: # Assuming `get_db()` returns a session object + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id, archived=False) - - # Normalize the tag_name for consistency tag_id = tag_name.replace(" ", "_").lower() if db.bind.dialect.name == "sqlite": - # SQLite JSON1 support for querying the tags inside the `meta` JSON field query = query.filter( text( - f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" + "EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" ) ).params(tag_id=tag_id) - elif db.bind.dialect.name == "postgresql": - # PostgreSQL JSONB support for querying the tags inside the `meta` JSON field query = query.filter( text( "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" ) ).params(tag_id=tag_id) - else: raise NotImplementedError( f"Unsupported dialect: {db.bind.dialect.name}" ) - # Get the count of matching records - count = query.count() + return query.count() - # Debugging output for inspection - log.info(f"Count of chats for tag '{tag_name}': {count}") + def delete_orphan_tags_for_user( + self, + tag_ids: list[str], + user_id: str, + threshold: int = 0, + db: Optional[Session] = None, + ) -> None: + """Delete tag rows from *tag_ids* that appear in at most *threshold* + non-archived chats for *user_id*. One query to find orphans, one to + delete them. - return count + Use threshold=0 after a tag is already removed from a chat's meta. + Use threshold=1 when the chat itself is about to be deleted (the + referencing chat still exists at query time). + """ + if not tag_ids: + return + with get_db_context(db) as db: + orphans = [] + for tag_id in tag_ids: + count = self.count_chats_by_tag_name_and_user_id( + tag_id, user_id, db=db + ) + if count <= threshold: + orphans.append(tag_id) + Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db) def count_chats_by_folder_id_and_user_id( self, folder_id: str, user_id: str, db: Optional[Session] = None diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 64cb559547..ef6ad1ded7 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -115,5 +115,45 @@ class TagTable: log.error(f"delete_tag: {e}") return False + def delete_tags_by_ids_and_user_id( + self, ids: list[str], user_id: str, db: Optional[Session] = None + ) -> bool: + """Delete all tags whose id is in *ids* for the given user, in one query.""" + if not ids: + return True + try: + with get_db_context(db) as db: + db.query(Tag).filter( + Tag.id.in_(ids), Tag.user_id == user_id + ).delete(synchronize_session=False) + db.commit() + return True + except Exception as e: + log.error(f"delete_tags_by_ids: {e}") + return False + + def ensure_tags_exist( + self, names: list[str], user_id: str, db: Optional[Session] = None + ) -> None: + """Create tag rows for any *names* that don't already exist for *user_id*.""" + if not names: + return + ids = [n.replace(" ", "_").lower() for n in names] + with get_db_context(db) as db: + existing = { + t.id + for t in db.query(Tag.id) + .filter(Tag.id.in_(ids), Tag.user_id == user_id) + .all() + } + new_tags = [ + Tag(id=tag_id, name=name, user_id=user_id) + for tag_id, name in zip(ids, names) + if tag_id not in existing + ] + if new_tags: + db.add_all(new_tags) + db.commit() + Tags = TagTable() diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index e03cdc7ba9..69e47123f0 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -1131,9 +1131,9 @@ async def delete_chat_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) + Chats.delete_orphan_tags_for_user( + chat.meta.get("tags", []), user.id, threshold=1, db=db + ) result = Chats.delete_chat_by_id(id, db=db) @@ -1153,9 +1153,9 @@ async def delete_chat_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) + Chats.delete_orphan_tags_for_user( + chat.meta.get("tags", []), user.id, threshold=1, db=db + ) result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db) return result @@ -1317,21 +1317,13 @@ async def archive_chat_by_id( if chat: chat = Chats.toggle_chat_archive_by_id(id, db=db) - # Delete tags if chat is archived + tag_ids = chat.meta.get("tags", []) if chat.archived: - for tag_id in chat.meta.get("tags", []): - if ( - Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db) - == 0 - ): - log.debug(f"deleting tag: {tag_id}") - Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db) + # Archived chats are excluded from count — clean up orphans + Chats.delete_orphan_tags_for_user(tag_ids, user.id, db=db) else: - for tag_id in chat.meta.get("tags", []): - tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db) - if tag is None: - log.debug(f"inserting tag: {tag_id}") - tag = Tags.insert_new_tag(tag_id, user.id, db=db) + # Unarchived — ensure tag rows exist + Tags.ensure_tags_exist(tag_ids, user.id, db=db) return ChatResponse(**chat.model_dump()) else: @@ -1537,11 +1529,9 @@ async def delete_all_tags_by_id( ): chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: + old_tags = chat.meta.get("tags", []) Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db) - - for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0: - Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) + Chats.delete_orphan_tags_for_user(old_tags, user.id, db=db) return True else: