This commit is contained in:
Timothy Jaeryang Baek
2026-02-13 13:56:29 -06:00
parent 20de5a87da
commit 589c4e64c1
7 changed files with 178 additions and 71 deletions

View File

@@ -515,6 +515,62 @@ class AccessGrantsTable:
)
return exists is not None
def get_accessible_resource_ids(
self,
user_id: str,
resource_type: str,
resource_ids: list[str],
permission: str = "read",
user_group_ids: Optional[set[str]] = None,
db: Optional[Session] = None,
) -> set[str]:
"""
Batch check: return the subset of resource_ids that the user can access.
This replaces calling has_access() in a loop (N+1) with a single query.
"""
if not resource_ids:
return set()
with get_db_context(db) as db:
conditions = [
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_id == "*",
),
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_id == user_id,
),
]
if user_group_ids is None:
from open_webui.models.groups import Groups
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
user_group_ids = {group.id for group in user_groups}
if user_group_ids:
conditions.append(
and_(
AccessGrant.principal_type == "group",
AccessGrant.principal_id.in_(user_group_ids),
)
)
rows = (
db.query(AccessGrant.resource_id)
.filter(
AccessGrant.resource_type == resource_type,
AccessGrant.resource_id.in_(resource_ids),
AccessGrant.permission == permission,
or_(*conditions),
)
.distinct()
.all()
)
return {row[0] for row in rows}
def get_users_with_access(
self,
resource_type: str,

View File

@@ -115,8 +115,10 @@ async def get_knowledge_bases(
skip = (page - 1) * limit
filter = {}
groups = Groups.get_groups_by_member_id(user.id, db=db)
user_group_ids = {group.id for group in groups}
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
@@ -126,6 +128,17 @@ async def get_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit, db=db
)
# Batch-fetch writable knowledge IDs in a single query instead of N has_access calls
knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items]
writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="knowledge",
resource_ids=knowledge_base_ids,
permission="write",
user_group_ids=user_group_ids,
db=db,
)
return KnowledgeAccessListResponse(
items=[
KnowledgeAccessResponse(
@@ -133,13 +146,7 @@ async def get_knowledge_bases(
write_access=(
user.id == knowledge_base.user_id
or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or AccessGrants.has_access(
user_id=user.id,
resource_type="knowledge",
resource_id=knowledge_base.id,
permission="write",
db=db,
)
or knowledge_base.id in writable_knowledge_base_ids
),
)
for knowledge_base in result.items
@@ -166,8 +173,10 @@ async def search_knowledge_bases(
if view_option:
filter["view_option"] = view_option
groups = Groups.get_groups_by_member_id(user.id, db=db)
user_group_ids = {group.id for group in groups}
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
@@ -177,6 +186,17 @@ async def search_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit, db=db
)
# Batch-fetch writable knowledge IDs in a single query instead of N has_access calls
knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items]
writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="knowledge",
resource_ids=knowledge_base_ids,
permission="write",
user_group_ids=user_group_ids,
db=db,
)
return KnowledgeAccessListResponse(
items=[
KnowledgeAccessResponse(
@@ -184,13 +204,7 @@ async def search_knowledge_bases(
write_access=(
user.id == knowledge_base.user_id
or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or AccessGrants.has_access(
user_id=user.id,
resource_type="knowledge",
resource_id=knowledge_base.id,
permission="write",
db=db,
)
or knowledge_base.id in writable_knowledge_base_ids
),
)
for knowledge_base in result.items

View File

@@ -96,6 +96,17 @@ async def get_models(
result = Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db)
# Batch-fetch writable model IDs in a single query instead of N has_access calls
model_ids = [model.id for model in result.items]
writable_model_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="model",
resource_ids=model_ids,
permission="write",
user_group_ids=user_group_ids,
db=db,
)
return ModelAccessListResponse(
items=[
ModelAccessResponse(
@@ -103,14 +114,7 @@ async def get_models(
write_access=(
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == model.user_id
or AccessGrants.has_access(
user_id=user.id,
resource_type="model",
resource_id=model.id,
permission="write",
user_group_ids=user_group_ids,
db=db,
)
or model.id in writable_model_ids
),
)
for model in result.items

View File

@@ -418,21 +418,24 @@ async def get_all_models(request: Request, user: UserModel = None):
async def get_filtered_models(models, user, db=None):
# Filter models based on user access control
model_ids = [model["model"] for model in models.get("models", [])]
model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)}
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="model",
resource_ids=list(model_infos.keys()),
permission="read",
user_group_ids=user_group_ids,
db=db,
)
filtered_models = []
for model in models.get("models", []):
model_info = model_infos.get(model["model"])
if model_info:
if user.id == model_info.user_id or AccessGrants.has_access(
user_id=user.id,
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
db=db,
):
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
filtered_models.append(model)
return filtered_models
@@ -1329,6 +1332,9 @@ async def generate_chat_completion(
# Check if user has access to the model
if not bypass_filter and user.role == "user":
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id)
}
if not (
user.id == model_info.user_id
or AccessGrants.has_access(
@@ -1336,6 +1342,7 @@ async def generate_chat_completion(
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
)
):
raise HTTPException(
@@ -1436,6 +1443,9 @@ async def generate_openai_completion(
# Check if user has access to the model
if user.role == "user":
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id)
}
if not (
user.id == model_info.user_id
or AccessGrants.has_access(
@@ -1443,6 +1453,7 @@ async def generate_openai_completion(
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
)
):
raise HTTPException(
@@ -1520,6 +1531,9 @@ async def generate_openai_chat_completion(
# Check if user has access to the model
if user.role == "user":
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id)
}
if not (
user.id == model_info.user_id
or AccessGrants.has_access(
@@ -1527,6 +1541,7 @@ async def generate_openai_chat_completion(
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
)
):
raise HTTPException(
@@ -1618,21 +1633,24 @@ async def get_openai_models(
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
# Filter models based on user access control
model_ids = [model["id"] for model in models]
model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)}
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="model",
resource_ids=list(model_infos.keys()),
permission="read",
user_group_ids=user_group_ids,
db=db,
)
filtered_models = []
for model in models:
model_info = model_infos.get(model["id"])
if model_info:
if user.id == model_info.user_id or AccessGrants.has_access(
user_id=user.id,
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
db=db,
):
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
filtered_models.append(model)
models = filtered_models

View File

@@ -455,21 +455,24 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
async def get_filtered_models(models, user, db=None):
# Filter models based on user access control
model_ids = [model["id"] for model in models.get("data", [])]
model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)}
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="model",
resource_ids=list(model_infos.keys()),
permission="read",
user_group_ids=user_group_ids,
db=db,
)
filtered_models = []
for model in models.get("data", []):
model_info = model_infos.get(model["id"])
if model_info:
if user.id == model_info.user_id or AccessGrants.has_access(
user_id=user.id,
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
db=db,
):
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
filtered_models.append(model)
return filtered_models
@@ -960,6 +963,9 @@ async def generate_chat_completion(
# Check if user has access to the model
if not bypass_filter and user.role == "user":
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id)
}
if not (
user.id == model_info.user_id
or AccessGrants.has_access(
@@ -967,6 +973,7 @@ async def generate_chat_completion(
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
)
):
raise HTTPException(

View File

@@ -114,6 +114,17 @@ async def get_prompt_list(
user.id, filter=filter, skip=skip, limit=limit, db=db
)
# Batch-fetch writable prompt IDs in a single query instead of N has_access calls
prompt_ids = [prompt.id for prompt in result.items]
writable_prompt_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="prompt",
resource_ids=prompt_ids,
permission="write",
user_group_ids=user_group_ids,
db=db,
)
return PromptAccessListResponse(
items=[
PromptAccessResponse(
@@ -121,14 +132,7 @@ async def get_prompt_list(
write_access=(
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == prompt.user_id
or AccessGrants.has_access(
user_id=user.id,
resource_type="prompt",
resource_id=prompt.id,
permission="write",
user_group_ids=user_group_ids,
db=db,
)
or prompt.id in writable_prompt_ids
),
)
for prompt in result.items

View File

@@ -377,10 +377,21 @@ def get_filtered_models(models, user, db=None):
for model_info in Models.get_models_by_ids(model_ids)
}
filtered_models = []
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
user_id=user.id,
resource_type="model",
resource_ids=list(model_infos.keys()),
permission="read",
user_group_ids=user_group_ids,
db=db,
)
filtered_models = []
for model in models:
if model.get("arena"):
meta = model.get("info", {}).get("meta", {})
@@ -399,14 +410,7 @@ def get_filtered_models(models, user, db=None):
if (
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == model_info.user_id
or AccessGrants.has_access(
user_id=user.id,
resource_type="model",
resource_id=model_info.id,
permission="read",
user_group_ids=user_group_ids,
db=db,
)
or model_info.id in accessible_model_ids
):
filtered_models.append(model)