mirror of
https://github.com/open-webui/open-webui.git
synced 2026-02-24 04:00:31 +01:00
refac
This commit is contained in:
@@ -431,22 +431,31 @@ class ChatTable:
|
||||
def update_chat_tags_by_id(
|
||||
self, id: str, tags: list[str], user
|
||||
) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
with get_db_context() as db:
|
||||
chat = db.get(Chat, id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
self.delete_all_tags_by_id_and_user_id(id, user.id)
|
||||
old_tags = chat.meta.get("tags", [])
|
||||
new_tags = [t for t in tags if t.replace(" ", "_").lower() != "none"]
|
||||
new_tag_ids = [t.replace(" ", "_").lower() for t in new_tags]
|
||||
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
# Single meta update
|
||||
chat.meta = {**chat.meta, "tags": new_tag_ids}
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
|
||||
for tag_name in tags:
|
||||
if tag_name.lower() == "none":
|
||||
continue
|
||||
# Batch-create any missing tag rows
|
||||
Tags.ensure_tags_exist(new_tags, user.id, db=db)
|
||||
|
||||
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
|
||||
return self.get_chat_by_id(id)
|
||||
# Clean up orphaned old tags in one query
|
||||
removed = set(old_tags) - set(new_tag_ids)
|
||||
if removed:
|
||||
self.delete_orphan_tags_for_user(
|
||||
list(removed), user.id, db=db
|
||||
)
|
||||
|
||||
return ChatModel.model_validate(chat)
|
||||
|
||||
def get_chat_title_by_id(self, id: str) -> Optional[str]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
@@ -1267,8 +1276,8 @@ class ChatTable:
|
||||
) -> list[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
|
||||
tag_ids = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=db)
|
||||
|
||||
def get_chat_list_by_user_id_and_tag_name(
|
||||
self,
|
||||
@@ -1309,20 +1318,16 @@ class ChatTable:
|
||||
def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None
|
||||
) -> Optional[ChatModel]:
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
|
||||
if tag is None:
|
||||
tag = Tags.insert_new_tag(tag_name, user_id)
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
Tags.ensure_tags_exist([tag_name], user_id, db=db)
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
|
||||
tag_id = tag.id
|
||||
if tag_id not in chat.meta.get("tags", []):
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
|
||||
}
|
||||
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
@@ -1332,40 +1337,55 @@ class ChatTable:
|
||||
def count_chats_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str, db: Optional[Session] = None
|
||||
) -> int:
|
||||
with get_db_context(db) as db: # Assuming `get_db()` returns a session object
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
|
||||
|
||||
# Normalize the tag_name for consistency
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
|
||||
elif db.bind.dialect.name == "postgresql":
|
||||
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
# Get the count of matching records
|
||||
count = query.count()
|
||||
return query.count()
|
||||
|
||||
# Debugging output for inspection
|
||||
log.info(f"Count of chats for tag '{tag_name}': {count}")
|
||||
def delete_orphan_tags_for_user(
|
||||
self,
|
||||
tag_ids: list[str],
|
||||
user_id: str,
|
||||
threshold: int = 0,
|
||||
db: Optional[Session] = None,
|
||||
) -> None:
|
||||
"""Delete tag rows from *tag_ids* that appear in at most *threshold*
|
||||
non-archived chats for *user_id*. One query to find orphans, one to
|
||||
delete them.
|
||||
|
||||
return count
|
||||
Use threshold=0 after a tag is already removed from a chat's meta.
|
||||
Use threshold=1 when the chat itself is about to be deleted (the
|
||||
referencing chat still exists at query time).
|
||||
"""
|
||||
if not tag_ids:
|
||||
return
|
||||
with get_db_context(db) as db:
|
||||
orphans = []
|
||||
for tag_id in tag_ids:
|
||||
count = self.count_chats_by_tag_name_and_user_id(
|
||||
tag_id, user_id, db=db
|
||||
)
|
||||
if count <= threshold:
|
||||
orphans.append(tag_id)
|
||||
Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db)
|
||||
|
||||
def count_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str, db: Optional[Session] = None
|
||||
|
||||
@@ -115,5 +115,45 @@ class TagTable:
|
||||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
def delete_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
"""Delete all tags whose id is in *ids* for the given user, in one query."""
|
||||
if not ids:
|
||||
return True
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Tag).filter(
|
||||
Tag.id.in_(ids), Tag.user_id == user_id
|
||||
).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_tags_by_ids: {e}")
|
||||
return False
|
||||
|
||||
def ensure_tags_exist(
|
||||
self, names: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> None:
|
||||
"""Create tag rows for any *names* that don't already exist for *user_id*."""
|
||||
if not names:
|
||||
return
|
||||
ids = [n.replace(" ", "_").lower() for n in names]
|
||||
with get_db_context(db) as db:
|
||||
existing = {
|
||||
t.id
|
||||
for t in db.query(Tag.id)
|
||||
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
|
||||
.all()
|
||||
}
|
||||
new_tags = [
|
||||
Tag(id=tag_id, name=name, user_id=user_id)
|
||||
for tag_id, name in zip(ids, names)
|
||||
if tag_id not in existing
|
||||
]
|
||||
if new_tags:
|
||||
db.add_all(new_tags)
|
||||
db.commit()
|
||||
|
||||
|
||||
Tags = TagTable()
|
||||
|
||||
@@ -1131,9 +1131,9 @@ async def delete_chat_by_id(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||
Chats.delete_orphan_tags_for_user(
|
||||
chat.meta.get("tags", []), user.id, threshold=1, db=db
|
||||
)
|
||||
|
||||
result = Chats.delete_chat_by_id(id, db=db)
|
||||
|
||||
@@ -1153,9 +1153,9 @@ async def delete_chat_by_id(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||
Chats.delete_orphan_tags_for_user(
|
||||
chat.meta.get("tags", []), user.id, threshold=1, db=db
|
||||
)
|
||||
|
||||
result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
return result
|
||||
@@ -1317,21 +1317,13 @@ async def archive_chat_by_id(
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id, db=db)
|
||||
|
||||
# Delete tags if chat is archived
|
||||
tag_ids = chat.meta.get("tags", [])
|
||||
if chat.archived:
|
||||
for tag_id in chat.meta.get("tags", []):
|
||||
if (
|
||||
Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db)
|
||||
== 0
|
||||
):
|
||||
log.debug(f"deleting tag: {tag_id}")
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||
# Archived chats are excluded from count — clean up orphans
|
||||
Chats.delete_orphan_tags_for_user(tag_ids, user.id, db=db)
|
||||
else:
|
||||
for tag_id in chat.meta.get("tags", []):
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||
if tag is None:
|
||||
log.debug(f"inserting tag: {tag_id}")
|
||||
tag = Tags.insert_new_tag(tag_id, user.id, db=db)
|
||||
# Unarchived — ensure tag rows exist
|
||||
Tags.ensure_tags_exist(tag_ids, user.id, db=db)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
@@ -1537,11 +1529,9 @@ async def delete_all_tags_by_id(
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
old_tags = chat.meta.get("tags", [])
|
||||
Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db)
|
||||
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||
Chats.delete_orphan_tags_for_user(old_tags, user.id, db=db)
|
||||
|
||||
return True
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user