mirror of
https://github.com/open-webui/open-webui.git
synced 2026-02-24 12:11:56 +01:00
refac
This commit is contained in:
@@ -93,7 +93,7 @@ from open_webui.utils.misc import (
|
||||
convert_logit_bias_input_to_json,
|
||||
get_content_from_message,
|
||||
)
|
||||
from open_webui.utils.tools import get_tools, get_updated_tool_function
|
||||
from open_webui.utils.tools import get_tools, get_updated_tool_function, has_tool_server_access
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
@@ -1663,6 +1663,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
log.error(f"MCP server with id {server_id} not found")
|
||||
continue
|
||||
|
||||
# Check access control for MCP server
|
||||
if not has_tool_server_access(user, mcp_server_connection):
|
||||
log.warning(f"Access denied to MCP server {server_id} for user {user.id}")
|
||||
continue
|
||||
|
||||
auth_type = mcp_server_connection.get("auth_type", "")
|
||||
headers = {}
|
||||
if auth_type == "bearer":
|
||||
|
||||
@@ -37,7 +37,10 @@ from langchain_core.utils.function_calling import (
|
||||
from open_webui.utils.misc import is_string_allowed
|
||||
from open_webui.models.tools import Tools
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.utils.plugin import load_tool_module_by_id
|
||||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
||||
@@ -130,15 +133,114 @@ def get_updated_tool_function(function: Callable, extra_params: dict):
|
||||
return function
|
||||
|
||||
|
||||
def has_tool_server_access(
|
||||
user: UserModel, server_connection: dict, user_group_ids: set = None
|
||||
) -> bool:
|
||||
"""Check if user has access to a tool server (MCP or OpenAPI)."""
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
return True
|
||||
|
||||
if user_group_ids is None:
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
||||
|
||||
access_control = server_connection.get("config", {}).get("access_control", None)
|
||||
return has_access(user.id, "read", access_control, user_group_ids)
|
||||
|
||||
|
||||
async def get_tools(
|
||||
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||
) -> dict[str, dict]:
|
||||
"""Load tools for the given tool_ids, checking access control."""
|
||||
tools_dict = {}
|
||||
|
||||
# Get user's group memberships for access control checks
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
||||
|
||||
for tool_id in tool_ids:
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
if tool is None:
|
||||
if tool:
|
||||
# Check access control for local tools
|
||||
if (
|
||||
not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
and tool.user_id != user.id
|
||||
and not has_access(user.id, "read", tool.access_control, user_group_ids)
|
||||
):
|
||||
log.warning(f"Access denied to tool {tool_id} for user {user.id}")
|
||||
continue
|
||||
|
||||
module = request.app.state.TOOLS.get(tool_id, None)
|
||||
if module is None:
|
||||
module, _ = load_tool_module_by_id(tool_id)
|
||||
request.app.state.TOOLS[tool_id] = module
|
||||
|
||||
__user__ = {
|
||||
**extra_params["__user__"],
|
||||
}
|
||||
|
||||
# Set valves for the tool
|
||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||
module.valves = module.Valves(**valves)
|
||||
if hasattr(module, "UserValves"):
|
||||
__user__["valves"] = module.UserValves( # type: ignore
|
||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||
)
|
||||
|
||||
for spec in tool.specs:
|
||||
# TODO: Fix hack for OpenAI API
|
||||
# Some times breaks OpenAI but others don't. Leaving the comment
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val.get("type") == "str":
|
||||
val["type"] = "string"
|
||||
|
||||
# Remove internal reserved parameters (e.g. __id__, __user__)
|
||||
spec["parameters"]["properties"] = {
|
||||
key: val
|
||||
for key, val in spec["parameters"]["properties"].items()
|
||||
if not key.startswith("__")
|
||||
}
|
||||
|
||||
# convert to function that takes only model params and inserts custom params
|
||||
function_name = spec["name"]
|
||||
tool_function = getattr(module, function_name)
|
||||
callable = get_async_tool_function_and_apply_extra_params(
|
||||
tool_function,
|
||||
{
|
||||
**extra_params,
|
||||
"__id__": tool_id,
|
||||
"__user__": __user__,
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: Support Pydantic models as parameters
|
||||
if callable.__doc__ and callable.__doc__.strip() != "":
|
||||
s = re.split(":(param|return)", callable.__doc__, 1)
|
||||
spec["description"] = s[0]
|
||||
else:
|
||||
spec["description"] = function_name
|
||||
|
||||
tool_dict = {
|
||||
"tool_id": tool_id,
|
||||
"callable": callable,
|
||||
"spec": spec,
|
||||
# Misc info
|
||||
"metadata": {
|
||||
"file_handler": hasattr(module, "file_handler")
|
||||
and module.file_handler,
|
||||
"citation": hasattr(module, "citation") and module.citation,
|
||||
},
|
||||
}
|
||||
|
||||
# Handle function name collisions
|
||||
while function_name in tools_dict:
|
||||
log.warning(
|
||||
f"Tool {function_name} already exists in another tools!"
|
||||
)
|
||||
# Prepend tool ID to function name
|
||||
function_name = f"{tool_id}_{function_name}"
|
||||
|
||||
tools_dict[function_name] = tool_dict
|
||||
else:
|
||||
if tool_id.startswith("server:"):
|
||||
splits = tool_id.split(":")
|
||||
|
||||
@@ -173,6 +275,15 @@ async def get_tools(
|
||||
]
|
||||
)
|
||||
|
||||
# Check access control for tool server
|
||||
if not has_tool_server_access(
|
||||
user, tool_server_connection, user_group_ids
|
||||
):
|
||||
log.warning(
|
||||
f"Access denied to tool server {server_id} for user {user.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
specs = tool_server_data.get("specs", [])
|
||||
function_name_filter_list = tool_server_connection.get(
|
||||
"config", {}
|
||||
@@ -267,82 +378,6 @@ async def get_tools(
|
||||
else:
|
||||
continue
|
||||
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
module = request.app.state.TOOLS.get(tool_id, None)
|
||||
if module is None:
|
||||
module, _ = load_tool_module_by_id(tool_id)
|
||||
request.app.state.TOOLS[tool_id] = module
|
||||
|
||||
__user__ = {
|
||||
**extra_params["__user__"],
|
||||
}
|
||||
|
||||
# Set valves for the tool
|
||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||
module.valves = module.Valves(**valves)
|
||||
if hasattr(module, "UserValves"):
|
||||
__user__["valves"] = module.UserValves( # type: ignore
|
||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||
)
|
||||
|
||||
for spec in tool.specs:
|
||||
# TODO: Fix hack for OpenAI API
|
||||
# Some times breaks OpenAI but others don't. Leaving the comment
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val.get("type") == "str":
|
||||
val["type"] = "string"
|
||||
|
||||
# Remove internal reserved parameters (e.g. __id__, __user__)
|
||||
spec["parameters"]["properties"] = {
|
||||
key: val
|
||||
for key, val in spec["parameters"]["properties"].items()
|
||||
if not key.startswith("__")
|
||||
}
|
||||
|
||||
# convert to function that takes only model params and inserts custom params
|
||||
function_name = spec["name"]
|
||||
tool_function = getattr(module, function_name)
|
||||
callable = get_async_tool_function_and_apply_extra_params(
|
||||
tool_function,
|
||||
{
|
||||
**extra_params,
|
||||
"__id__": tool_id,
|
||||
"__user__": __user__,
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: Support Pydantic models as parameters
|
||||
if callable.__doc__ and callable.__doc__.strip() != "":
|
||||
s = re.split(":(param|return)", callable.__doc__, 1)
|
||||
spec["description"] = s[0]
|
||||
else:
|
||||
spec["description"] = function_name
|
||||
|
||||
tool_dict = {
|
||||
"tool_id": tool_id,
|
||||
"callable": callable,
|
||||
"spec": spec,
|
||||
# Misc info
|
||||
"metadata": {
|
||||
"file_handler": hasattr(module, "file_handler")
|
||||
and module.file_handler,
|
||||
"citation": hasattr(module, "citation") and module.citation,
|
||||
},
|
||||
}
|
||||
|
||||
# Handle function name collisions
|
||||
while function_name in tools_dict:
|
||||
log.warning(
|
||||
f"Tool {function_name} already exists in another tools!"
|
||||
)
|
||||
# Prepend tool ID to function name
|
||||
function_name = f"{tool_id}_{function_name}"
|
||||
|
||||
tools_dict[function_name] = tool_dict
|
||||
|
||||
return tools_dict
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user