mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 20:07:49 +01:00
fix: multi-user tags issue
This commit is contained in:
@@ -8,7 +8,7 @@ from open_webui.apps.webui.internal.db import Base, get_db
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, JSON
|
||||
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@@ -19,11 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
####################
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
id = Column(String, primary_key=True)
|
||||
id = Column(String)
|
||||
name = Column(String)
|
||||
user_id = Column(String)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
|
||||
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
|
||||
|
||||
|
||||
class TagModel(BaseModel):
|
||||
id: str
|
||||
@@ -57,7 +60,8 @@ class TagTable:
|
||||
return TagModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
@@ -78,11 +82,15 @@ class TagTable:
|
||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
||||
]
|
||||
|
||||
def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
|
||||
def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str
|
||||
) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
|
||||
for tag in (
|
||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
|
||||
)
|
||||
]
|
||||
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||
|
||||
@@ -465,7 +465,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -494,7 +494,7 @@ async def add_tag_by_id_and_tag_name(
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
@@ -519,7 +519,7 @@ async def delete_tag_by_id_and_tag_name(
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -543,7 +543,7 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
||||
Reference in New Issue
Block a user