From 9b06fdc8fe1c933071610336be05f11e77e6c8eb Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 8 Jan 2026 03:37:11 +0400 Subject: [PATCH] refac --- backend/open_webui/utils/middleware.py | 7 +- backend/open_webui/utils/tools.py | 189 +++++++++++++++---------- 2 files changed, 118 insertions(+), 78 deletions(-) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index d7d8a4fcee..26883dba4a 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -93,7 +93,7 @@ from open_webui.utils.misc import ( convert_logit_bias_input_to_json, get_content_from_message, ) -from open_webui.utils.tools import get_tools, get_updated_tool_function +from open_webui.utils.tools import get_tools, get_updated_tool_function, has_tool_server_access from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.filter import ( get_sorted_filter_ids, @@ -1663,6 +1663,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): log.error(f"MCP server with id {server_id} not found") continue + # Check access control for MCP server + if not has_tool_server_access(user, mcp_server_connection): + log.warning(f"Access denied to MCP server {server_id} for user {user.id}") + continue + auth_type = mcp_server_connection.get("auth_type", "") headers = {} if auth_type == "bearer": diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 478c29c44d..700b4c6765 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -37,7 +37,10 @@ from langchain_core.utils.function_calling import ( from open_webui.utils.misc import is_string_allowed from open_webui.models.tools import Tools from open_webui.models.users import UserModel +from open_webui.models.groups import Groups from open_webui.utils.plugin import load_tool_module_by_id +from open_webui.utils.access_control import has_access +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, @@ -130,15 +133,114 @@ def get_updated_tool_function(function: Callable, extra_params: dict): return function +def has_tool_server_access( + user: UserModel, server_connection: dict, user_group_ids: set = None +) -> bool: + """Check if user has access to a tool server (MCP or OpenAPI).""" + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + return True + + if user_group_ids is None: + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} + + access_control = server_connection.get("config", {}).get("access_control", None) + return has_access(user.id, "read", access_control, user_group_ids) + + async def get_tools( request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: + """Load tools for the given tool_ids, checking access control.""" tools_dict = {} + # Get user's group memberships for access control checks + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} + for tool_id in tool_ids: tool = Tools.get_tool_by_id(tool_id) - if tool is None: + if tool: + # Check access control for local tools + if ( + not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + and tool.user_id != user.id + and not has_access(user.id, "read", tool.access_control, user_group_ids) + ): + log.warning(f"Access denied to tool {tool_id} for user {user.id}") + continue + module = request.app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_tool_module_by_id(tool_id) + request.app.state.TOOLS[tool_id] = module + + __user__ = { + **extra_params["__user__"], + } + + # Set valves for the tool + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + if hasattr(module, "UserValves"): + __user__["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in tool.specs: + # TODO: Fix hack for OpenAI API + # Some times breaks OpenAI but others don't. Leaving the comment + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val.get("type") == "str": + val["type"] = "string" + + # Remove internal reserved parameters (e.g. __id__, __user__) + spec["parameters"]["properties"] = { + key: val + for key, val in spec["parameters"]["properties"].items() + if not key.startswith("__") + } + + # convert to function that takes only model params and inserts custom params + function_name = spec["name"] + tool_function = getattr(module, function_name) + callable = get_async_tool_function_and_apply_extra_params( + tool_function, + { + **extra_params, + "__id__": tool_id, + "__user__": __user__, + }, + ) + + # TODO: Support Pydantic models as parameters + if callable.__doc__ and callable.__doc__.strip() != "": + s = re.split(":(param|return)", callable.__doc__, 1) + spec["description"] = s[0] + else: + spec["description"] = function_name + + tool_dict = { + "tool_id": tool_id, + "callable": callable, + "spec": spec, + # Misc info + "metadata": { + "file_handler": hasattr(module, "file_handler") + and module.file_handler, + "citation": hasattr(module, "citation") and module.citation, + }, + } + + # Handle function name collisions + while function_name in tools_dict: + log.warning( + f"Tool {function_name} already exists in another tools!" + ) + # Prepend tool ID to function name + function_name = f"{tool_id}_{function_name}" + + tools_dict[function_name] = tool_dict + else: if tool_id.startswith("server:"): splits = tool_id.split(":") @@ -173,6 +275,15 @@ async def get_tools( ] ) + # Check access control for tool server + if not has_tool_server_access( + user, tool_server_connection, user_group_ids + ): + log.warning( + f"Access denied to tool server {server_id} for user {user.id}" + ) + continue + specs = tool_server_data.get("specs", []) function_name_filter_list = tool_server_connection.get( "config", {} @@ -267,82 +378,6 @@ async def get_tools( else: continue - else: - continue - else: - module = request.app.state.TOOLS.get(tool_id, None) - if module is None: - module, _ = load_tool_module_by_id(tool_id) - request.app.state.TOOLS[tool_id] = module - - __user__ = { - **extra_params["__user__"], - } - - # Set valves for the tool - if hasattr(module, "valves") and hasattr(module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) or {} - module.valves = module.Valves(**valves) - if hasattr(module, "UserValves"): - __user__["valves"] = module.UserValves( # type: ignore - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - - for spec in tool.specs: - # TODO: Fix hack for OpenAI API - # Some times breaks OpenAI but others don't. Leaving the comment - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val.get("type") == "str": - val["type"] = "string" - - # Remove internal reserved parameters (e.g. __id__, __user__) - spec["parameters"]["properties"] = { - key: val - for key, val in spec["parameters"]["properties"].items() - if not key.startswith("__") - } - - # convert to function that takes only model params and inserts custom params - function_name = spec["name"] - tool_function = getattr(module, function_name) - callable = get_async_tool_function_and_apply_extra_params( - tool_function, - { - **extra_params, - "__id__": tool_id, - "__user__": __user__, - }, - ) - - # TODO: Support Pydantic models as parameters - if callable.__doc__ and callable.__doc__.strip() != "": - s = re.split(":(param|return)", callable.__doc__, 1) - spec["description"] = s[0] - else: - spec["description"] = function_name - - tool_dict = { - "tool_id": tool_id, - "callable": callable, - "spec": spec, - # Misc info - "metadata": { - "file_handler": hasattr(module, "file_handler") - and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, - }, - } - - # Handle function name collisions - while function_name in tools_dict: - log.warning( - f"Tool {function_name} already exists in another tools!" - ) - # Prepend tool ID to function name - function_name = f"{tool_id}_{function_name}" - - tools_dict[function_name] = tool_dict - return tools_dict