From 589c4e64c1b7bb7a7a5abc20382b92fb860e28c2 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 13 Feb 2026 13:56:29 -0600 Subject: [PATCH] refac --- backend/open_webui/models/access_grants.py | 56 +++++++++++++++++++++ backend/open_webui/routers/knowledge.py | 46 +++++++++++------ backend/open_webui/routers/models.py | 20 +++++--- backend/open_webui/routers/ollama.py | 58 ++++++++++++++-------- backend/open_webui/routers/openai.py | 27 ++++++---- backend/open_webui/routers/prompts.py | 20 +++++--- backend/open_webui/utils/models.py | 22 ++++---- 7 files changed, 178 insertions(+), 71 deletions(-) diff --git a/backend/open_webui/models/access_grants.py b/backend/open_webui/models/access_grants.py index fa6e79a8db..dd5a344b46 100644 --- a/backend/open_webui/models/access_grants.py +++ b/backend/open_webui/models/access_grants.py @@ -515,6 +515,62 @@ class AccessGrantsTable: ) return exists is not None + def get_accessible_resource_ids( + self, + user_id: str, + resource_type: str, + resource_ids: list[str], + permission: str = "read", + user_group_ids: Optional[set[str]] = None, + db: Optional[Session] = None, + ) -> set[str]: + """ + Batch check: return the subset of resource_ids that the user can access. + + This replaces calling has_access() in a loop (N+1) with a single query. + """ + if not resource_ids: + return set() + + with get_db_context(db) as db: + conditions = [ + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == "*", + ), + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == user_id, + ), + ] + + if user_group_ids is None: + from open_webui.models.groups import Groups + + user_groups = Groups.get_groups_by_member_id(user_id, db=db) + user_group_ids = {group.id for group in user_groups} + + if user_group_ids: + conditions.append( + and_( + AccessGrant.principal_type == "group", + AccessGrant.principal_id.in_(user_group_ids), + ) + ) + + rows = ( + db.query(AccessGrant.resource_id) + .filter( + AccessGrant.resource_type == resource_type, + AccessGrant.resource_id.in_(resource_ids), + AccessGrant.permission == permission, + or_(*conditions), + ) + .distinct() + .all() + ) + return {row[0] for row in rows} + def get_users_with_access( self, resource_type: str, diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index d4b1e0a803..d620c1745f 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -115,8 +115,10 @@ async def get_knowledge_bases( skip = (page - 1) * limit filter = {} + groups = Groups.get_groups_by_member_id(user.id, db=db) + user_group_ids = {group.id for group in groups} + if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] @@ -126,6 +128,17 @@ async def get_knowledge_bases( user.id, filter=filter, skip=skip, limit=limit, db=db ) + # Batch-fetch writable knowledge IDs in a single query instead of N has_access calls + knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items] + writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="knowledge", + resource_ids=knowledge_base_ids, + permission="write", + user_group_ids=user_group_ids, + db=db, + ) + return KnowledgeAccessListResponse( items=[ KnowledgeAccessResponse( @@ -133,13 +146,7 @@ async def get_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or AccessGrants.has_access( - user_id=user.id, - resource_type="knowledge", - resource_id=knowledge_base.id, - permission="write", - db=db, - ) + or knowledge_base.id in writable_knowledge_base_ids ), ) for knowledge_base in result.items @@ -166,8 +173,10 @@ async def search_knowledge_bases( if view_option: filter["view_option"] = view_option + groups = Groups.get_groups_by_member_id(user.id, db=db) + user_group_ids = {group.id for group in groups} + if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] @@ -177,6 +186,17 @@ async def search_knowledge_bases( user.id, filter=filter, skip=skip, limit=limit, db=db ) + # Batch-fetch writable knowledge IDs in a single query instead of N has_access calls + knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items] + writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="knowledge", + resource_ids=knowledge_base_ids, + permission="write", + user_group_ids=user_group_ids, + db=db, + ) + return KnowledgeAccessListResponse( items=[ KnowledgeAccessResponse( @@ -184,13 +204,7 @@ async def search_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or AccessGrants.has_access( - user_id=user.id, - resource_type="knowledge", - resource_id=knowledge_base.id, - permission="write", - db=db, - ) + or knowledge_base.id in writable_knowledge_base_ids ), ) for knowledge_base in result.items diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 43065faa13..fe4137cab9 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -96,6 +96,17 @@ async def get_models( result = Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db) + # Batch-fetch writable model IDs in a single query instead of N has_access calls + model_ids = [model.id for model in result.items] + writable_model_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="model", + resource_ids=model_ids, + permission="write", + user_group_ids=user_group_ids, + db=db, + ) + return ModelAccessListResponse( items=[ ModelAccessResponse( @@ -103,14 +114,7 @@ async def get_models( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id - or AccessGrants.has_access( - user_id=user.id, - resource_type="model", - resource_id=model.id, - permission="write", - user_group_ids=user_group_ids, - db=db, - ) + or model.id in writable_model_ids ), ) for model in result.items diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 7356a2b0ba..394735e898 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -418,21 +418,24 @@ async def get_all_models(request: Request, user: UserModel = None): async def get_filtered_models(models, user, db=None): # Filter models based on user access control model_ids = [model["model"] for model in models.get("models", [])] - model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)} - user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)} + model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + + # Batch-fetch accessible resource IDs in a single query instead of N has_access calls + accessible_model_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="model", + resource_ids=list(model_infos.keys()), + permission="read", + user_group_ids=user_group_ids, + db=db, + ) filtered_models = [] for model in models.get("models", []): model_info = model_infos.get(model["model"]) if model_info: - if user.id == model_info.user_id or AccessGrants.has_access( - user_id=user.id, - resource_type="model", - resource_id=model_info.id, - permission="read", - user_group_ids=user_group_ids, - db=db, - ): + if user.id == model_info.user_id or model_info.id in accessible_model_ids: filtered_models.append(model) return filtered_models @@ -1329,6 +1332,9 @@ async def generate_chat_completion( # Check if user has access to the model if not bypass_filter and user.role == "user": + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id) + } if not ( user.id == model_info.user_id or AccessGrants.has_access( @@ -1336,6 +1342,7 @@ async def generate_chat_completion( resource_type="model", resource_id=model_info.id, permission="read", + user_group_ids=user_group_ids, ) ): raise HTTPException( @@ -1436,6 +1443,9 @@ async def generate_openai_completion( # Check if user has access to the model if user.role == "user": + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id) + } if not ( user.id == model_info.user_id or AccessGrants.has_access( @@ -1443,6 +1453,7 @@ async def generate_openai_completion( resource_type="model", resource_id=model_info.id, permission="read", + user_group_ids=user_group_ids, ) ): raise HTTPException( @@ -1520,6 +1531,9 @@ async def generate_openai_chat_completion( # Check if user has access to the model if user.role == "user": + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id) + } if not ( user.id == model_info.user_id or AccessGrants.has_access( @@ -1527,6 +1541,7 @@ async def generate_openai_chat_completion( resource_type="model", resource_id=model_info.id, permission="read", + user_group_ids=user_group_ids, ) ): raise HTTPException( @@ -1618,21 +1633,24 @@ async def get_openai_models( if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control model_ids = [model["id"] for model in models] - model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)} - user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)} + model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + + # Batch-fetch accessible resource IDs in a single query instead of N has_access calls + accessible_model_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="model", + resource_ids=list(model_infos.keys()), + permission="read", + user_group_ids=user_group_ids, + db=db, + ) filtered_models = [] for model in models: model_info = model_infos.get(model["id"]) if model_info: - if user.id == model_info.user_id or AccessGrants.has_access( - user_id=user.id, - resource_type="model", - resource_id=model_info.id, - permission="read", - user_group_ids=user_group_ids, - db=db, - ): + if user.id == model_info.user_id or model_info.id in accessible_model_ids: filtered_models.append(model) models = filtered_models diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index de978d011a..f8688a9c93 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -455,21 +455,24 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: async def get_filtered_models(models, user, db=None): # Filter models based on user access control model_ids = [model["id"] for model in models.get("data", [])] - model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)} - user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)} + model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + + # Batch-fetch accessible resource IDs in a single query instead of N has_access calls + accessible_model_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="model", + resource_ids=list(model_infos.keys()), + permission="read", + user_group_ids=user_group_ids, + db=db, + ) filtered_models = [] for model in models.get("data", []): model_info = model_infos.get(model["id"]) if model_info: - if user.id == model_info.user_id or AccessGrants.has_access( - user_id=user.id, - resource_type="model", - resource_id=model_info.id, - permission="read", - user_group_ids=user_group_ids, - db=db, - ): + if user.id == model_info.user_id or model_info.id in accessible_model_ids: filtered_models.append(model) return filtered_models @@ -960,6 +963,9 @@ async def generate_chat_completion( # Check if user has access to the model if not bypass_filter and user.role == "user": + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id) + } if not ( user.id == model_info.user_id or AccessGrants.has_access( @@ -967,6 +973,7 @@ async def generate_chat_completion( resource_type="model", resource_id=model_info.id, permission="read", + user_group_ids=user_group_ids, ) ): raise HTTPException( diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 86d2648a88..0060ab2b18 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -114,6 +114,17 @@ async def get_prompt_list( user.id, filter=filter, skip=skip, limit=limit, db=db ) + # Batch-fetch writable prompt IDs in a single query instead of N has_access calls + prompt_ids = [prompt.id for prompt in result.items] + writable_prompt_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="prompt", + resource_ids=prompt_ids, + permission="write", + user_group_ids=user_group_ids, + db=db, + ) + return PromptAccessListResponse( items=[ PromptAccessResponse( @@ -121,14 +132,7 @@ async def get_prompt_list( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or AccessGrants.has_access( - user_id=user.id, - resource_type="prompt", - resource_id=prompt.id, - permission="write", - user_group_ids=user_group_ids, - db=db, - ) + or prompt.id in writable_prompt_ids ), ) for prompt in result.items diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index ff3a6e0caf..8bef1591bb 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -377,10 +377,21 @@ def get_filtered_models(models, user, db=None): for model_info in Models.get_models_by_ids(model_ids) } - filtered_models = [] user_group_ids = { group.id for group in Groups.get_groups_by_member_id(user.id, db=db) } + + # Batch-fetch accessible resource IDs in a single query instead of N has_access calls + accessible_model_ids = AccessGrants.get_accessible_resource_ids( + user_id=user.id, + resource_type="model", + resource_ids=list(model_infos.keys()), + permission="read", + user_group_ids=user_group_ids, + db=db, + ) + + filtered_models = [] for model in models: if model.get("arena"): meta = model.get("info", {}).get("meta", {}) @@ -399,14 +410,7 @@ def get_filtered_models(models, user, db=None): if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model_info.user_id - or AccessGrants.has_access( - user_id=user.id, - resource_type="model", - resource_id=model_info.id, - permission="read", - user_group_ids=user_group_ids, - db=db, - ) + or model_info.id in accessible_model_ids ): filtered_models.append(model)