mirror of
https://github.com/open-webui/open-webui.git
synced 2026-02-24 12:11:56 +01:00
chore: format
This commit is contained in:
@@ -622,8 +622,12 @@ class ChatTable:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
# Use subquery to delete chat_messages for shared chats
|
||||
shared_chat_id_subquery = db.query(Chat.id).filter_by(user_id=f"shared-{chat_id}").subquery()
|
||||
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_chat_id_subquery)).delete(synchronize_session=False)
|
||||
shared_chat_id_subquery = (
|
||||
db.query(Chat.id).filter_by(user_id=f"shared-{chat_id}").subquery()
|
||||
)
|
||||
db.query(ChatMessage).filter(
|
||||
ChatMessage.chat_id.in_(shared_chat_id_subquery)
|
||||
).delete(synchronize_session=False)
|
||||
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
|
||||
db.commit()
|
||||
|
||||
@@ -1441,8 +1445,12 @@ class ChatTable:
|
||||
with get_db_context(db) as db:
|
||||
self.delete_shared_chats_by_user_id(user_id, db=db)
|
||||
|
||||
chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id).subquery()
|
||||
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete(synchronize_session=False)
|
||||
chat_id_subquery = (
|
||||
db.query(Chat.id).filter_by(user_id=user_id).subquery()
|
||||
)
|
||||
db.query(ChatMessage).filter(
|
||||
ChatMessage.chat_id.in_(chat_id_subquery)
|
||||
).delete(synchronize_session=False)
|
||||
db.query(Chat).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
@@ -1455,8 +1463,14 @@ class ChatTable:
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id, folder_id=folder_id).subquery()
|
||||
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete(synchronize_session=False)
|
||||
chat_id_subquery = (
|
||||
db.query(Chat.id)
|
||||
.filter_by(user_id=user_id, folder_id=folder_id)
|
||||
.subquery()
|
||||
)
|
||||
db.query(ChatMessage).filter(
|
||||
ChatMessage.chat_id.in_(chat_id_subquery)
|
||||
).delete(synchronize_session=False)
|
||||
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
|
||||
db.commit()
|
||||
|
||||
@@ -1491,8 +1505,14 @@ class ChatTable:
|
||||
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
|
||||
|
||||
# Use subquery to delete chat_messages for shared chats
|
||||
shared_id_subq = db.query(Chat.id).filter(Chat.user_id.in_(shared_chat_ids)).subquery()
|
||||
db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_id_subq)).delete(synchronize_session=False)
|
||||
shared_id_subq = (
|
||||
db.query(Chat.id)
|
||||
.filter(Chat.user_id.in_(shared_chat_ids))
|
||||
.subquery()
|
||||
)
|
||||
db.query(ChatMessage).filter(
|
||||
ChatMessage.chat_id.in_(shared_id_subq)
|
||||
).delete(synchronize_session=False)
|
||||
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -363,9 +363,7 @@ class UsersTable:
|
||||
query = db.query(User)
|
||||
if dialect_name == "sqlite":
|
||||
query = query.filter(
|
||||
User.scim.contains(
|
||||
{provider: {"external_id": external_id}}
|
||||
)
|
||||
User.scim.contains({provider: {"external_id": external_id}})
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
query = query.filter(
|
||||
@@ -533,7 +531,12 @@ class UsersTable:
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> list[UserStatusModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
|
||||
users = (
|
||||
db.query(User)
|
||||
.options(defer(User.profile_image_url))
|
||||
.filter(User.id.in_(user_ids))
|
||||
.all()
|
||||
)
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
|
||||
|
||||
@@ -567,7 +567,10 @@ async def update_knowledge_access_by_id(
|
||||
form_data.access_grants = [
|
||||
grant
|
||||
for grant in form_data.access_grants
|
||||
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
|
||||
if not (
|
||||
grant.get("principal_type") == "user"
|
||||
and grant.get("principal_id") == "*"
|
||||
)
|
||||
]
|
||||
|
||||
AccessGrants.set_access_grants("knowledge", id, form_data.access_grants, db=db)
|
||||
|
||||
@@ -577,7 +577,10 @@ async def update_model_access_by_id(
|
||||
form_data.access_grants = [
|
||||
grant
|
||||
for grant in form_data.access_grants
|
||||
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
|
||||
if not (
|
||||
grant.get("principal_type") == "user"
|
||||
and grant.get("principal_id") == "*"
|
||||
)
|
||||
]
|
||||
|
||||
AccessGrants.set_access_grants(
|
||||
|
||||
@@ -358,7 +358,10 @@ async def update_note_access_by_id(
|
||||
form_data.access_grants = [
|
||||
grant
|
||||
for grant in form_data.access_grants
|
||||
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
|
||||
if not (
|
||||
grant.get("principal_type") == "user"
|
||||
and grant.get("principal_id") == "*"
|
||||
)
|
||||
]
|
||||
|
||||
AccessGrants.set_access_grants("note", id, form_data.access_grants, db=db)
|
||||
|
||||
@@ -418,8 +418,13 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
async def get_filtered_models(models, user, db=None):
|
||||
# Filter models based on user access control
|
||||
model_ids = [model["model"] for model in models.get("models", [])]
|
||||
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
model_infos = {
|
||||
model_info.id: model_info
|
||||
for model_info in Models.get_models_by_ids(model_ids, db=db)
|
||||
}
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||
}
|
||||
|
||||
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
|
||||
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
@@ -1633,8 +1638,13 @@ async def get_openai_models(
|
||||
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
# Filter models based on user access control
|
||||
model_ids = [model["id"] for model in models]
|
||||
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
model_infos = {
|
||||
model_info.id: model_info
|
||||
for model_info in Models.get_models_by_ids(model_ids, db=db)
|
||||
}
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||
}
|
||||
|
||||
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
|
||||
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
@@ -1650,7 +1660,10 @@ async def get_openai_models(
|
||||
for model in models:
|
||||
model_info = model_infos.get(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
|
||||
if (
|
||||
user.id == model_info.user_id
|
||||
or model_info.id in accessible_model_ids
|
||||
):
|
||||
filtered_models.append(model)
|
||||
models = filtered_models
|
||||
|
||||
|
||||
@@ -486,7 +486,10 @@ async def update_prompt_access_by_id(
|
||||
form_data.access_grants = [
|
||||
grant
|
||||
for grant in form_data.access_grants
|
||||
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
|
||||
if not (
|
||||
grant.get("principal_type") == "user"
|
||||
and grant.get("principal_id") == "*"
|
||||
)
|
||||
]
|
||||
|
||||
AccessGrants.set_access_grants("prompt", prompt_id, form_data.access_grants, db=db)
|
||||
|
||||
@@ -329,9 +329,7 @@ def get_scim_provider() -> str:
|
||||
return SCIM_AUTH_PROVIDER
|
||||
|
||||
|
||||
def find_user_by_external_id(
|
||||
external_id: str, db=None
|
||||
) -> Optional[UserModel]:
|
||||
def find_user_by_external_id(external_id: str, db=None) -> Optional[UserModel]:
|
||||
"""Find a user by SCIM externalId, falling back to OAuth sub match."""
|
||||
provider = get_scim_provider()
|
||||
user = Users.get_user_by_scim_external_id(provider, external_id, db=db)
|
||||
@@ -652,9 +650,7 @@ async def create_user(
|
||||
# Store externalId in the scim field
|
||||
if user_data.externalId:
|
||||
provider = get_scim_provider()
|
||||
Users.update_user_scim_by_id(
|
||||
user_id, provider, user_data.externalId, db=db
|
||||
)
|
||||
Users.update_user_scim_by_id(user_id, provider, user_data.externalId, db=db)
|
||||
new_user = Users.get_user_by_id(user_id, db=db)
|
||||
|
||||
return user_to_scim(new_user, request, db=db)
|
||||
@@ -711,9 +707,7 @@ async def update_user(
|
||||
# Update externalId in the scim field
|
||||
if user_data.externalId:
|
||||
provider = get_scim_provider()
|
||||
Users.update_user_scim_by_id(
|
||||
user_id, provider, user_data.externalId, db=db
|
||||
)
|
||||
Users.update_user_scim_by_id(user_id, provider, user_data.externalId, db=db)
|
||||
updated_user = Users.get_user_by_id(user_id, db=db)
|
||||
|
||||
return user_to_scim(updated_user, request, db=db)
|
||||
@@ -755,9 +749,7 @@ async def patch_user(
|
||||
update_data["name"] = value
|
||||
elif path == "externalId":
|
||||
provider = get_scim_provider()
|
||||
Users.update_user_scim_by_id(
|
||||
user_id, provider, value, db=db
|
||||
)
|
||||
Users.update_user_scim_by_id(user_id, provider, value, db=db)
|
||||
|
||||
# Update user
|
||||
if update_data:
|
||||
|
||||
@@ -354,7 +354,10 @@ async def update_skill_access_by_id(
|
||||
form_data.access_grants = [
|
||||
grant
|
||||
for grant in form_data.access_grants
|
||||
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
|
||||
if not (
|
||||
grant.get("principal_type") == "user"
|
||||
and grant.get("principal_id") == "*"
|
||||
)
|
||||
]
|
||||
|
||||
AccessGrants.set_access_grants("skill", id, form_data.access_grants, db=db)
|
||||
|
||||
@@ -568,7 +568,10 @@ async def update_tool_access_by_id(
|
||||
form_data.access_grants = [
|
||||
grant
|
||||
for grant in form_data.access_grants
|
||||
if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*")
|
||||
if not (
|
||||
grant.get("principal_type") == "user"
|
||||
and grant.get("principal_id") == "*"
|
||||
)
|
||||
]
|
||||
|
||||
AccessGrants.set_access_grants("tool", id, form_data.access_grants, db=db)
|
||||
|
||||
Reference in New Issue
Block a user