chore: format

This commit is contained in:
Timothy Jaeryang Baek
2026-02-13 15:00:39 -06:00
parent 79ecbfc757
commit 626d236d13
10 changed files with 81 additions and 35 deletions

View File

@@ -622,8 +622,12 @@ class ChatTable:
try:
with get_db_context(db) as db:
# Use subquery to delete chat_messages for shared chats
shared_chat_id_subquery = db.query(Chat.id).filter_by(user_id=f"shared-{chat_id}").subquery()
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_chat_id_subquery)).delete(synchronize_session=False)
shared_chat_id_subquery = (
db.query(Chat.id).filter_by(user_id=f"shared-{chat_id}").subquery()
)
db.query(ChatMessage).filter(
ChatMessage.chat_id.in_(shared_chat_id_subquery)
).delete(synchronize_session=False)
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit()
@@ -1441,8 +1445,12 @@ class ChatTable:
with get_db_context(db) as db:
self.delete_shared_chats_by_user_id(user_id, db=db)
chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id).subquery()
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete(synchronize_session=False)
chat_id_subquery = (
db.query(Chat.id).filter_by(user_id=user_id).subquery()
)
db.query(ChatMessage).filter(
ChatMessage.chat_id.in_(chat_id_subquery)
).delete(synchronize_session=False)
db.query(Chat).filter_by(user_id=user_id).delete()
db.commit()
@@ -1455,8 +1463,14 @@ class ChatTable:
) -> bool:
try:
with get_db_context(db) as db:
chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id, folder_id=folder_id).subquery()
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete(synchronize_session=False)
chat_id_subquery = (
db.query(Chat.id)
.filter_by(user_id=user_id, folder_id=folder_id)
.subquery()
)
db.query(ChatMessage).filter(
ChatMessage.chat_id.in_(chat_id_subquery)
).delete(synchronize_session=False)
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
db.commit()
@@ -1491,8 +1505,14 @@ class ChatTable:
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
# Use subquery to delete chat_messages for shared chats
shared_id_subq = db.query(Chat.id).filter(Chat.user_id.in_(shared_chat_ids)).subquery()
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_id_subq)).delete(synchronize_session=False)
shared_id_subq = (
db.query(Chat.id)
.filter(Chat.user_id.in_(shared_chat_ids))
.subquery()
)
db.query(ChatMessage).filter(
ChatMessage.chat_id.in_(shared_id_subq)
).delete(synchronize_session=False)
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
db.commit()

View File

@@ -363,9 +363,7 @@ class UsersTable:
query = db.query(User)
if dialect_name == "sqlite":
query = query.filter(
User.scim.contains(
{provider: {"external_id": external_id}}
)
User.scim.contains({provider: {"external_id": external_id}})
)
elif dialect_name == "postgresql":
query = query.filter(
@@ -533,7 +531,12 @@ class UsersTable:
self, user_ids: list[str], db: Optional[Session] = None
) -> list[UserStatusModel]:
with get_db_context(db) as db:
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
users = (
db.query(User)
.options(defer(User.profile_image_url))
.filter(User.id.in_(user_ids))
.all()
)
return [UserModel.model_validate(user) for user in users]
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:

View File

@@ -567,7 +567,10 @@ async def update_knowledge_access_by_id(
form_data.access_grants = [
grant
for grant in form_data.access_grants
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
if not (
grant.get("principal_type") == "user"
and grant.get("principal_id") == "*"
)
]
AccessGrants.set_access_grants("knowledge", id, form_data.access_grants, db=db)

View File

@@ -577,7 +577,10 @@ async def update_model_access_by_id(
form_data.access_grants = [
grant
for grant in form_data.access_grants
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
if not (
grant.get("principal_type") == "user"
and grant.get("principal_id") == "*"
)
]
AccessGrants.set_access_grants(

View File

@@ -358,7 +358,10 @@ async def update_note_access_by_id(
form_data.access_grants = [
grant
for grant in form_data.access_grants
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
if not (
grant.get("principal_type") == "user"
and grant.get("principal_id") == "*"
)
]
AccessGrants.set_access_grants("note", id, form_data.access_grants, db=db)

View File

@@ -418,8 +418,13 @@ async def get_all_models(request: Request, user: UserModel = None):
async def get_filtered_models(models, user, db=None):
# Filter models based on user access control
model_ids = [model["model"] for model in models.get("models", [])]
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
model_infos = {
model_info.id: model_info
for model_info in Models.get_models_by_ids(model_ids, db=db)
}
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
@@ -1633,8 +1638,13 @@ async def get_openai_models(
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
# Filter models based on user access control
model_ids = [model["id"] for model in models]
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
model_infos = {
model_info.id: model_info
for model_info in Models.get_models_by_ids(model_ids, db=db)
}
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
@@ -1650,7 +1660,10 @@ async def get_openai_models(
for model in models:
model_info = model_infos.get(model["id"])
if model_info:
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
if (
user.id == model_info.user_id
or model_info.id in accessible_model_ids
):
filtered_models.append(model)
models = filtered_models

View File

@@ -486,7 +486,10 @@ async def update_prompt_access_by_id(
form_data.access_grants = [
grant
for grant in form_data.access_grants
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
if not (
grant.get("principal_type") == "user"
and grant.get("principal_id") == "*"
)
]
AccessGrants.set_access_grants("prompt", prompt_id, form_data.access_grants, db=db)

View File

@@ -329,9 +329,7 @@ def get_scim_provider() -> str:
return SCIM_AUTH_PROVIDER
def find_user_by_external_id(
external_id: str, db=None
) -> Optional[UserModel]:
def find_user_by_external_id(external_id: str, db=None) -> Optional[UserModel]:
"""Find a user by SCIM externalId, falling back to OAuth sub match."""
provider = get_scim_provider()
user = Users.get_user_by_scim_external_id(provider, external_id, db=db)
@@ -652,9 +650,7 @@ async def create_user(
# Store externalId in the scim field
if user_data.externalId:
provider = get_scim_provider()
Users.update_user_scim_by_id(
user_id, provider, user_data.externalId, db=db
)
Users.update_user_scim_by_id(user_id, provider, user_data.externalId, db=db)
new_user = Users.get_user_by_id(user_id, db=db)
return user_to_scim(new_user, request, db=db)
@@ -711,9 +707,7 @@ async def update_user(
# Update externalId in the scim field
if user_data.externalId:
provider = get_scim_provider()
Users.update_user_scim_by_id(
user_id, provider, user_data.externalId, db=db
)
Users.update_user_scim_by_id(user_id, provider, user_data.externalId, db=db)
updated_user = Users.get_user_by_id(user_id, db=db)
return user_to_scim(updated_user, request, db=db)
@@ -755,9 +749,7 @@ async def patch_user(
update_data["name"] = value
elif path == "externalId":
provider = get_scim_provider()
Users.update_user_scim_by_id(
user_id, provider, value, db=db
)
Users.update_user_scim_by_id(user_id, provider, value, db=db)
# Update user
if update_data:

View File

@@ -354,7 +354,10 @@ async def update_skill_access_by_id(
form_data.access_grants = [
grant
for grant in form_data.access_grants
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
if not (
grant.get("principal_type") == "user"
and grant.get("principal_id") == "*"
)
]
AccessGrants.set_access_grants("skill", id, form_data.access_grants, db=db)

View File

@@ -568,7 +568,10 @@ async def update_tool_access_by_id(
form_data.access_grants = [
grant
for grant in form_data.access_grants
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
if not (
grant.get("principal_type") == "user"
and grant.get("principal_id") == "*"
)
]
AccessGrants.set_access_grants("tool", id, form_data.access_grants, db=db)