diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 125d6cc3cd..7ae9f7a38b 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -622,8 +622,12 @@ class ChatTable: try: with get_db_context(db) as db: # Use subquery to delete chat_messages for shared chats - shared_chat_id_subquery = db.query(Chat.id).filter_by(user_id=f"shared-{chat_id}").subquery() - db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_chat_id_subquery)).delete(synchronize_session=False) + shared_chat_id_subquery = ( + db.query(Chat.id).filter_by(user_id=f"shared-{chat_id}").subquery() + ) + db.query(ChatMessage).filter( + ChatMessage.chat_id.in_(shared_chat_id_subquery) + ).delete(synchronize_session=False) db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() db.commit() @@ -1441,8 +1445,12 @@ class ChatTable: with get_db_context(db) as db: self.delete_shared_chats_by_user_id(user_id, db=db) - chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id).subquery() - db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete(synchronize_session=False) + chat_id_subquery = ( + db.query(Chat.id).filter_by(user_id=user_id).subquery() + ) + db.query(ChatMessage).filter( + ChatMessage.chat_id.in_(chat_id_subquery) + ).delete(synchronize_session=False) db.query(Chat).filter_by(user_id=user_id).delete() db.commit() @@ -1455,8 +1463,14 @@ class ChatTable: ) -> bool: try: with get_db_context(db) as db: - chat_id_subquery = db.query(Chat.id).filter_by(user_id=user_id, folder_id=folder_id).subquery() - db.query(ChatMessage).filter(ChatMessage.chat_id.in_(chat_id_subquery)).delete(synchronize_session=False) + chat_id_subquery = ( + db.query(Chat.id) + .filter_by(user_id=user_id, folder_id=folder_id) + .subquery() + ) + db.query(ChatMessage).filter( + ChatMessage.chat_id.in_(chat_id_subquery) + ).delete(synchronize_session=False) db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() db.commit() @@ -1491,8 +1505,14 @@ class ChatTable: shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] # Use subquery to delete chat_messages for shared chats - shared_id_subq = db.query(Chat.id).filter(Chat.user_id.in_(shared_chat_ids)).subquery() - db.query(ChatMessage).filter(ChatMessage.chat_id.in_(shared_id_subq)).delete(synchronize_session=False) + shared_id_subq = ( + db.query(Chat.id) + .filter(Chat.user_id.in_(shared_chat_ids)) + .subquery() + ) + db.query(ChatMessage).filter( + ChatMessage.chat_id.in_(shared_id_subq) + ).delete(synchronize_session=False) db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() db.commit() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 2eb76131ab..e5da9231df 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -363,9 +363,7 @@ class UsersTable: query = db.query(User) if dialect_name == "sqlite": query = query.filter( - User.scim.contains( - {provider: {"external_id": external_id}} - ) + User.scim.contains({provider: {"external_id": external_id}}) ) elif dialect_name == "postgresql": query = query.filter( @@ -533,7 +531,12 @@ class UsersTable: self, user_ids: list[str], db: Optional[Session] = None ) -> list[UserStatusModel]: with get_db_context(db) as db: - users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all() + users = ( + db.query(User) + .options(defer(User.profile_image_url)) + .filter(User.id.in_(user_ids)) + .all() + ) return [UserModel.model_validate(user) for user in users] def get_num_users(self, db: Optional[Session] = None) -> Optional[int]: diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index d620c1745f..1fedab4466 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -567,7 +567,10 @@ async def update_knowledge_access_by_id( form_data.access_grants = [ grant for grant in form_data.access_grants - if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*") + if not ( + grant.get("principal_type") == "user" + and grant.get("principal_id") == "*" + ) ] AccessGrants.set_access_grants("knowledge", id, form_data.access_grants, db=db) diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index aa573dd720..e93d8a729d 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -577,7 +577,10 @@ async def update_model_access_by_id( form_data.access_grants = [ grant for grant in form_data.access_grants - if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*") + if not ( + grant.get("principal_type") == "user" + and grant.get("principal_id") == "*" + ) ] AccessGrants.set_access_grants( diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index cba4c3f4ca..8d1a66c4af 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -358,7 +358,10 @@ async def update_note_access_by_id( form_data.access_grants = [ grant for grant in form_data.access_grants - if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*") + if not ( + grant.get("principal_type") == "user" + and grant.get("principal_id") == "*" + ) ] AccessGrants.set_access_grants("note", id, form_data.access_grants, db=db) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 394735e898..580717987c 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -418,8 +418,13 @@ async def get_all_models(request: Request, user: UserModel = None): async def get_filtered_models(models, user, db=None): # Filter models based on user access control model_ids = [model["model"] for model in models.get("models", [])] - model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + model_infos = { + model_info.id: model_info + for model_info in Models.get_models_by_ids(model_ids, db=db) + } + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } # Batch-fetch accessible resource IDs in a single query instead of N has_access calls accessible_model_ids = AccessGrants.get_accessible_resource_ids( @@ -1633,8 +1638,13 @@ async def get_openai_models( if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control model_ids = [model["id"] for model in models] - model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + model_infos = { + model_info.id: model_info + for model_info in Models.get_models_by_ids(model_ids, db=db) + } + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } # Batch-fetch accessible resource IDs in a single query instead of N has_access calls accessible_model_ids = AccessGrants.get_accessible_resource_ids( @@ -1650,7 +1660,10 @@ async def get_openai_models( for model in models: model_info = model_infos.get(model["id"]) if model_info: - if user.id == model_info.user_id or model_info.id in accessible_model_ids: + if ( + user.id == model_info.user_id + or model_info.id in accessible_model_ids + ): filtered_models.append(model) models = filtered_models diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 0060ab2b18..2491578959 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -486,7 +486,10 @@ async def update_prompt_access_by_id( form_data.access_grants = [ grant for grant in form_data.access_grants - if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*") + if not ( + grant.get("principal_type") == "user" + and grant.get("principal_id") == "*" + ) ] AccessGrants.set_access_grants("prompt", prompt_id, form_data.access_grants, db=db) diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index 13d3b5bdf2..0c16eb99bd 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -329,9 +329,7 @@ def get_scim_provider() -> str: return SCIM_AUTH_PROVIDER -def find_user_by_external_id( - external_id: str, db=None -) -> Optional[UserModel]: +def find_user_by_external_id(external_id: str, db=None) -> Optional[UserModel]: """Find a user by SCIM externalId, falling back to OAuth sub match.""" provider = get_scim_provider() user = Users.get_user_by_scim_external_id(provider, external_id, db=db) @@ -652,9 +650,7 @@ async def create_user( # Store externalId in the scim field if user_data.externalId: provider = get_scim_provider() - Users.update_user_scim_by_id( - user_id, provider, user_data.externalId, db=db - ) + Users.update_user_scim_by_id(user_id, provider, user_data.externalId, db=db) new_user = Users.get_user_by_id(user_id, db=db) return user_to_scim(new_user, request, db=db) @@ -711,9 +707,7 @@ async def update_user( # Update externalId in the scim field if user_data.externalId: provider = get_scim_provider() - Users.update_user_scim_by_id( - user_id, provider, user_data.externalId, db=db - ) + Users.update_user_scim_by_id(user_id, provider, user_data.externalId, db=db) updated_user = Users.get_user_by_id(user_id, db=db) return user_to_scim(updated_user, request, db=db) @@ -755,9 +749,7 @@ async def patch_user( update_data["name"] = value elif path == "externalId": provider = get_scim_provider() - Users.update_user_scim_by_id( - user_id, provider, value, db=db - ) + Users.update_user_scim_by_id(user_id, provider, value, db=db) # Update user if update_data: diff --git a/backend/open_webui/routers/skills.py b/backend/open_webui/routers/skills.py index 2a51b993c8..fb7b01b87f 100644 --- a/backend/open_webui/routers/skills.py +++ b/backend/open_webui/routers/skills.py @@ -354,7 +354,10 @@ async def update_skill_access_by_id( form_data.access_grants = [ grant for grant in form_data.access_grants - if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*") + if not ( + grant.get("principal_type") == "user" + and grant.get("principal_id") == "*" + ) ] AccessGrants.set_access_grants("skill", id, form_data.access_grants, db=db) diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 60fecbb6fc..6657b34462 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -568,7 +568,10 @@ async def update_tool_access_by_id( form_data.access_grants = [ grant for grant in form_data.access_grants - if not (grant.get("principal_type") == "user" and grant.get("principal_id") == "*") + if not ( + grant.get("principal_type") == "user" + and grant.get("principal_id") == "*" + ) ] AccessGrants.set_access_grants("tool", id, form_data.access_grants, db=db)