mirror of
https://github.com/open-webui/open-webui.git
synced 2026-02-24 04:00:31 +01:00
refac
This commit is contained in:
@@ -515,6 +515,62 @@ class AccessGrantsTable:
|
||||
)
|
||||
return exists is not None
|
||||
|
||||
def get_accessible_resource_ids(
|
||||
self,
|
||||
user_id: str,
|
||||
resource_type: str,
|
||||
resource_ids: list[str],
|
||||
permission: str = "read",
|
||||
user_group_ids: Optional[set[str]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Batch check: return the subset of resource_ids that the user can access.
|
||||
|
||||
This replaces calling has_access() in a loop (N+1) with a single query.
|
||||
"""
|
||||
if not resource_ids:
|
||||
return set()
|
||||
|
||||
with get_db_context(db) as db:
|
||||
conditions = [
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_id == "*",
|
||||
),
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_id == user_id,
|
||||
),
|
||||
]
|
||||
|
||||
if user_group_ids is None:
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_group_ids = {group.id for group in user_groups}
|
||||
|
||||
if user_group_ids:
|
||||
conditions.append(
|
||||
and_(
|
||||
AccessGrant.principal_type == "group",
|
||||
AccessGrant.principal_id.in_(user_group_ids),
|
||||
)
|
||||
)
|
||||
|
||||
rows = (
|
||||
db.query(AccessGrant.resource_id)
|
||||
.filter(
|
||||
AccessGrant.resource_type == resource_type,
|
||||
AccessGrant.resource_id.in_(resource_ids),
|
||||
AccessGrant.permission == permission,
|
||||
or_(*conditions),
|
||||
)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
return {row[0] for row in rows}
|
||||
|
||||
def get_users_with_access(
|
||||
self,
|
||||
resource_type: str,
|
||||
|
||||
@@ -115,8 +115,10 @@ async def get_knowledge_bases(
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
user_group_ids = {group.id for group in groups}
|
||||
|
||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
@@ -126,6 +128,17 @@ async def get_knowledge_bases(
|
||||
user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
# Batch-fetch writable knowledge IDs in a single query instead of N has_access calls
|
||||
knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items]
|
||||
writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="knowledge",
|
||||
resource_ids=knowledge_base_ids,
|
||||
permission="write",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
return KnowledgeAccessListResponse(
|
||||
items=[
|
||||
KnowledgeAccessResponse(
|
||||
@@ -133,13 +146,7 @@ async def get_knowledge_bases(
|
||||
write_access=(
|
||||
user.id == knowledge_base.user_id
|
||||
or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="knowledge",
|
||||
resource_id=knowledge_base.id,
|
||||
permission="write",
|
||||
db=db,
|
||||
)
|
||||
or knowledge_base.id in writable_knowledge_base_ids
|
||||
),
|
||||
)
|
||||
for knowledge_base in result.items
|
||||
@@ -166,8 +173,10 @@ async def search_knowledge_bases(
|
||||
if view_option:
|
||||
filter["view_option"] = view_option
|
||||
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
user_group_ids = {group.id for group in groups}
|
||||
|
||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
@@ -177,6 +186,17 @@ async def search_knowledge_bases(
|
||||
user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
# Batch-fetch writable knowledge IDs in a single query instead of N has_access calls
|
||||
knowledge_base_ids = [knowledge_base.id for knowledge_base in result.items]
|
||||
writable_knowledge_base_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="knowledge",
|
||||
resource_ids=knowledge_base_ids,
|
||||
permission="write",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
return KnowledgeAccessListResponse(
|
||||
items=[
|
||||
KnowledgeAccessResponse(
|
||||
@@ -184,13 +204,7 @@ async def search_knowledge_bases(
|
||||
write_access=(
|
||||
user.id == knowledge_base.user_id
|
||||
or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="knowledge",
|
||||
resource_id=knowledge_base.id,
|
||||
permission="write",
|
||||
db=db,
|
||||
)
|
||||
or knowledge_base.id in writable_knowledge_base_ids
|
||||
),
|
||||
)
|
||||
for knowledge_base in result.items
|
||||
|
||||
@@ -96,6 +96,17 @@ async def get_models(
|
||||
|
||||
result = Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db)
|
||||
|
||||
# Batch-fetch writable model IDs in a single query instead of N has_access calls
|
||||
model_ids = [model.id for model in result.items]
|
||||
writable_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_ids=model_ids,
|
||||
permission="write",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
return ModelAccessListResponse(
|
||||
items=[
|
||||
ModelAccessResponse(
|
||||
@@ -103,14 +114,7 @@ async def get_models(
|
||||
write_access=(
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == model.user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_id=model.id,
|
||||
permission="write",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
or model.id in writable_model_ids
|
||||
),
|
||||
)
|
||||
for model in result.items
|
||||
|
||||
@@ -418,21 +418,24 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
async def get_filtered_models(models, user, db=None):
|
||||
# Filter models based on user access control
|
||||
model_ids = [model["model"] for model in models.get("models", [])]
|
||||
model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
|
||||
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
|
||||
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_ids=list(model_infos.keys()),
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
filtered_models = []
|
||||
for model in models.get("models", []):
|
||||
model_info = model_infos.get(model["model"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
):
|
||||
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
|
||||
filtered_models.append(model)
|
||||
return filtered_models
|
||||
|
||||
@@ -1329,6 +1332,9 @@ async def generate_chat_completion(
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id)
|
||||
}
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or AccessGrants.has_access(
|
||||
@@ -1336,6 +1342,7 @@ async def generate_chat_completion(
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
@@ -1436,6 +1443,9 @@ async def generate_openai_completion(
|
||||
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id)
|
||||
}
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or AccessGrants.has_access(
|
||||
@@ -1443,6 +1453,7 @@ async def generate_openai_completion(
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
@@ -1520,6 +1531,9 @@ async def generate_openai_chat_completion(
|
||||
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id)
|
||||
}
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or AccessGrants.has_access(
|
||||
@@ -1527,6 +1541,7 @@ async def generate_openai_chat_completion(
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
@@ -1618,21 +1633,24 @@ async def get_openai_models(
|
||||
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
# Filter models based on user access control
|
||||
model_ids = [model["id"] for model in models]
|
||||
model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
|
||||
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
|
||||
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_ids=list(model_infos.keys()),
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
model_info = model_infos.get(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
):
|
||||
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
|
||||
filtered_models.append(model)
|
||||
models = filtered_models
|
||||
|
||||
|
||||
@@ -455,21 +455,24 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||
async def get_filtered_models(models, user, db=None):
|
||||
# Filter models based on user access control
|
||||
model_ids = [model["id"] for model in models.get("data", [])]
|
||||
model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
|
||||
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
|
||||
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_ids=list(model_infos.keys()),
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
filtered_models = []
|
||||
for model in models.get("data", []):
|
||||
model_info = model_infos.get(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
):
|
||||
if user.id == model_info.user_id or model_info.id in accessible_model_ids:
|
||||
filtered_models.append(model)
|
||||
return filtered_models
|
||||
|
||||
@@ -960,6 +963,9 @@ async def generate_chat_completion(
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id)
|
||||
}
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or AccessGrants.has_access(
|
||||
@@ -967,6 +973,7 @@ async def generate_chat_completion(
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
|
||||
@@ -114,6 +114,17 @@ async def get_prompt_list(
|
||||
user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
# Batch-fetch writable prompt IDs in a single query instead of N has_access calls
|
||||
prompt_ids = [prompt.id for prompt in result.items]
|
||||
writable_prompt_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="prompt",
|
||||
resource_ids=prompt_ids,
|
||||
permission="write",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
return PromptAccessListResponse(
|
||||
items=[
|
||||
PromptAccessResponse(
|
||||
@@ -121,14 +132,7 @@ async def get_prompt_list(
|
||||
write_access=(
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == prompt.user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="prompt",
|
||||
resource_id=prompt.id,
|
||||
permission="write",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
or prompt.id in writable_prompt_ids
|
||||
),
|
||||
)
|
||||
for prompt in result.items
|
||||
|
||||
@@ -377,10 +377,21 @@ def get_filtered_models(models, user, db=None):
|
||||
for model_info in Models.get_models_by_ids(model_ids)
|
||||
}
|
||||
|
||||
filtered_models = []
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||
}
|
||||
|
||||
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
|
||||
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_ids=list(model_infos.keys()),
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
if model.get("arena"):
|
||||
meta = model.get("info", {}).get("meta", {})
|
||||
@@ -399,14 +410,7 @@ def get_filtered_models(models, user, db=None):
|
||||
if (
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == model_info.user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="model",
|
||||
resource_id=model_info.id,
|
||||
permission="read",
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
)
|
||||
or model_info.id in accessible_model_ids
|
||||
):
|
||||
filtered_models.append(model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user