diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 6ef3c16091..b3a332adee 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -337,7 +337,7 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) return models -def check_model_access(user, model): +def check_model_access(user, model, db=None): if model.get("arena"): if not has_access( user.id, @@ -345,16 +345,17 @@ def check_model_access(user, model): access_control=model.get("info", {}) .get("meta", {}) .get("access_control", {}), + db=db, ): raise Exception("Model not found") else: - model_info = Models.get_model_by_id(model.get("id")) + model_info = Models.get_model_by_id(model.get("id"), db=db) if not model_info: raise Exception("Model not found") elif not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ) ): raise Exception("Model not found") @@ -373,7 +374,9 @@ def get_filtered_models(models, user, db=None): } filtered_models = [] - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } for model in models: if model.get("arena"): if has_access(