mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 19:37:47 +01:00
enh: access control
This commit is contained in:
@@ -7,6 +7,8 @@ from open_webui.apps.webui.models.groups import Groups
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
####################
|
||||
# Prompts DB Schema
|
||||
####################
|
||||
@@ -107,58 +109,12 @@ class PromptsTable:
|
||||
) -> list[PromptModel]:
|
||||
prompts = self.get_prompts()
|
||||
|
||||
groups = Groups.get_groups_by_member_id(user_id)
|
||||
group_ids = [group.id for group in groups]
|
||||
|
||||
if permission == "write":
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or (
|
||||
prompt.access_control
|
||||
and (
|
||||
any(
|
||||
group_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"group_ids", []
|
||||
)
|
||||
for group_id in group_ids
|
||||
)
|
||||
or (
|
||||
user_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"user_ids", []
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
elif permission == "read":
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or prompt.access_control is None
|
||||
or (
|
||||
prompt.access_control
|
||||
and (
|
||||
any(
|
||||
prompt.access_control.get(permission, {}).get(
|
||||
"group_ids", []
|
||||
)
|
||||
in group_id
|
||||
for group_id in group_ids
|
||||
)
|
||||
or (
|
||||
user_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"user_ids", []
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or has_access(user_id, permission, prompt.access_control)
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
|
||||
@@ -8,6 +8,9 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
@@ -133,6 +136,18 @@ class ToolsTable:
|
||||
with get_db() as db:
|
||||
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
|
||||
|
||||
def get_tools_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ToolModel]:
|
||||
tools = self.get_tools()
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user_id
|
||||
or has_access(tool.access_control, user_id, permission)
|
||||
]
|
||||
|
||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -14,7 +14,22 @@ router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=list[PromptModel])
|
||||
async def get_prompts(user=Depends(get_verified_user)):
|
||||
return Prompts.get_prompts()
|
||||
if user.role == "admin":
|
||||
prompts = Prompts.get_prompts()
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[PromptModel])
|
||||
async def get_prompt_list(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
prompts = Prompts.get_prompts()
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
############################
|
||||
@@ -23,7 +38,7 @@ async def get_prompts(user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[PromptModel])
|
||||
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
|
||||
async def create_new_prompt(form_data: PromptForm, user=Depends(get_verified_user)):
|
||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
||||
if prompt is None:
|
||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
||||
@@ -67,7 +82,7 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||
async def update_prompt_by_command(
|
||||
command: str,
|
||||
form_data: PromptForm,
|
||||
user=Depends(get_admin_user),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
||||
if prompt:
|
||||
@@ -85,6 +100,6 @@ async def update_prompt_by_command(
|
||||
|
||||
|
||||
@router.delete("/command/{command}/delete", response_model=bool)
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
||||
return result
|
||||
|
||||
@@ -14,37 +14,54 @@ from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetToolkits
|
||||
# GetTools
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ToolResponse])
|
||||
async def get_toolkits(user=Depends(get_verified_user)):
|
||||
toolkits = [toolkit for toolkit in Tools.get_tools()]
|
||||
return toolkits
|
||||
async def get_tools(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
tools = Tools.get_tools()
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "read")
|
||||
return tools
|
||||
|
||||
|
||||
############################
|
||||
# ExportToolKits
|
||||
# GetToolList
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[ToolResponse])
|
||||
async def get_tool_list(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
tools = Tools.get_tools()
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "write")
|
||||
return tools
|
||||
|
||||
|
||||
############################
|
||||
# ExportTools
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/export", response_model=list[ToolModel])
|
||||
async def get_toolkits(user=Depends(get_admin_user)):
|
||||
toolkits = [toolkit for toolkit in Tools.get_tools()]
|
||||
return toolkits
|
||||
async def export_tools(user=Depends(get_admin_user)):
|
||||
tools = Tools.get_tools()
|
||||
return tools
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewToolKit
|
||||
# CreateNewTools
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[ToolResponse])
|
||||
async def create_new_toolkit(
|
||||
async def create_new_tools(
|
||||
request: Request,
|
||||
form_data: ToolForm,
|
||||
user=Depends(get_admin_user),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if not form_data.id.isidentifier():
|
||||
raise HTTPException(
|
||||
@@ -93,12 +110,12 @@ async def create_new_toolkit(
|
||||
|
||||
|
||||
############################
|
||||
# GetToolkitById
|
||||
# GetToolsById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}", response_model=Optional[ToolModel])
|
||||
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
|
||||
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
|
||||
if toolkit:
|
||||
@@ -111,16 +128,16 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
|
||||
|
||||
|
||||
############################
|
||||
# UpdateToolkitById
|
||||
# UpdateToolsById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
|
||||
async def update_toolkit_by_id(
|
||||
async def update_tools_by_id(
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: ToolForm,
|
||||
user=Depends(get_admin_user),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
@@ -158,12 +175,14 @@ async def update_toolkit_by_id(
|
||||
|
||||
|
||||
############################
|
||||
# DeleteToolkitById
|
||||
# DeleteToolsById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
|
||||
async def delete_tools_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
result = Tools.delete_tool_by_id(id)
|
||||
|
||||
if result:
|
||||
@@ -180,7 +199,7 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||
async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
try:
|
||||
@@ -204,8 +223,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
||||
async def get_toolkit_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
async def get_tools_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
@@ -232,8 +251,8 @@ async def get_toolkit_valves_spec_by_id(
|
||||
|
||||
|
||||
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
||||
async def update_toolkit_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
||||
async def update_tools_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
@@ -276,7 +295,7 @@ async def update_toolkit_valves_by_id(
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||
async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
try:
|
||||
@@ -295,7 +314,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
||||
async def get_toolkit_user_valves_spec_by_id(
|
||||
async def get_tools_user_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
@@ -318,7 +337,7 @@ async def get_toolkit_user_valves_spec_by_id(
|
||||
|
||||
|
||||
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
||||
async def update_toolkit_user_valves_by_id(
|
||||
async def update_tools_user_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
|
||||
Reference in New Issue
Block a user