fix: multi-user tags issue

This commit is contained in:
Timothy J. Baek
2024-10-13 01:00:38 -07:00
parent 5273dc4535
commit 8ae605ec4b
3 changed files with 84 additions and 9 deletions

View File

@@ -8,7 +8,7 @@ from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, JSON
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -19,11 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tag(Base):
__tablename__ = "tag"
id = Column(String, primary_key=True)
id = Column(String)
name = Column(String)
user_id = Column(String)
meta = Column(JSON, nullable=True)
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
class TagModel(BaseModel):
id: str
@@ -57,7 +60,8 @@ class TagTable:
return TagModel.model_validate(result)
else:
return None
except Exception:
except Exception as e:
print(e)
return None
def get_tag_by_name_and_user_id(
@@ -78,11 +82,15 @@ class TagTable:
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
]
def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str
) -> list[TagModel]:
with get_db() as db:
return [
TagModel.model_validate(tag)
for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
for tag in (
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
)
]
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:

View File

@@ -465,7 +465,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -494,7 +494,7 @@ async def add_tag_by_id_and_tag_name(
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -519,7 +519,7 @@ async def delete_tag_by_id_and_tag_name(
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -543,7 +543,7 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND