Files
Timothy Jaeryang Baek 176f9a7816 refac
2026-02-23 16:01:03 -06:00

437 lines
12 KiB
Python

import logging
from typing import Optional
from open_webui.models.groups import Groups
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
from open_webui.models.skills import (
SkillForm,
SkillModel,
SkillResponse,
SkillUserResponse,
SkillAccessResponse,
SkillAccessListResponse,
Skills,
)
from open_webui.models.access_grants import AccessGrants
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
from open_webui.constants import ERROR_MESSAGES
log = logging.getLogger(__name__)
PAGE_ITEM_COUNT = 30
router = APIRouter()
############################
# GetSkills
############################
@router.get("/", response_model=list[SkillUserResponse])
async def get_skills(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
skills = Skills.get_skills(db=db)
else:
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
all_skills = Skills.get_skills(db=db)
skills = [
skill
for skill in all_skills
if skill.user_id == user.id
or AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="read",
user_group_ids=user_group_ids,
db=db,
)
]
return skills
############################
# GetSkillList
############################
@router.get("/list", response_model=SkillAccessListResponse)
async def get_skill_list(
query: Optional[str] = None,
view_option: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
limit = PAGE_ITEM_COUNT
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if view_option:
filter["view_option"] = view_option
if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL):
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
result = Skills.search_skills(user.id, filter=filter, skip=skip, limit=limit, db=db)
return SkillAccessListResponse(
items=[
SkillAccessResponse(
**skill.model_dump(),
write_access=(
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == skill.user_id
or AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="write",
db=db,
)
),
)
for skill in result.items
],
total=result.total,
)
############################
# ExportSkills
############################
@router.get("/export", response_model=list[SkillModel])
async def export_skills(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id,
"workspace.skills",
request.app.state.config.USER_PERMISSIONS,
db=db,
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
return Skills.get_skills(db=db)
else:
return Skills.get_skills_by_user_id(user.id, "read", db=db)
############################
# CreateNewSkill
############################
@router.post("/create", response_model=Optional[SkillResponse])
async def create_new_skill(
request: Request,
form_data: SkillForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "workspace.skills", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
form_data.id = form_data.id.lower().replace(" ", "-")
existing = Skills.get_skill_by_id(form_data.id, db=db)
if existing is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
try:
skill = Skills.insert_new_skill(user.id, form_data, db=db)
if skill:
return skill
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating skill"),
)
except Exception as e:
log.exception(f"Failed to create skill: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
############################
# GetSkillById
############################
@router.get("/id/{id}", response_model=Optional[SkillAccessResponse])
async def get_skill_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
skill = Skills.get_skill_by_id(id, db=db)
if skill:
if (
user.role == "admin"
or skill.user_id == user.id
or AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="read",
db=db,
)
):
return SkillAccessResponse(
**skill.model_dump(),
write_access=(
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == skill.user_id
or AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="write",
db=db,
)
),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateSkillById
############################
@router.post("/id/{id}/update", response_model=Optional[SkillModel])
async def update_skill_by_id(
request: Request,
id: str,
form_data: SkillForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
skill = Skills.get_skill_by_id(id, db=db)
if not skill:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
skill.user_id != user.id
and not AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="write",
db=db,
)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
try:
updated = {
**form_data.model_dump(exclude={"id"}),
}
skill = Skills.update_skill_by_id(id, updated, db=db)
if skill:
return skill
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating skill"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
############################
# UpdateSkillAccessById
############################
class SkillAccessGrantsForm(BaseModel):
access_grants: list[dict]
@router.post("/id/{id}/access/update", response_model=Optional[SkillModel])
async def update_skill_access_by_id(
request: Request,
id: str,
form_data: SkillAccessGrantsForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
skill = Skills.get_skill_by_id(id, db=db)
if not skill:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
skill.user_id != user.id
and not AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="write",
db=db,
)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
form_data.access_grants = filter_allowed_access_grants(
request.app.state.config.USER_PERMISSIONS,
user.id,
user.role,
form_data.access_grants,
"sharing.public_skills"
)
AccessGrants.set_access_grants("skill", id, form_data.access_grants, db=db)
return Skills.get_skill_by_id(id, db=db)
############################
# ToggleSkillById
############################
@router.post("/id/{id}/toggle", response_model=Optional[SkillModel])
async def toggle_skill_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
skill = Skills.get_skill_by_id(id, db=db)
if skill:
if (
user.role == "admin"
or skill.user_id == user.id
or AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="write",
db=db,
)
):
skill = Skills.toggle_skill_by_id(id, db=db)
if skill:
return skill
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error toggling skill"),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# DeleteSkillById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_skill_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
skill = Skills.get_skill_by_id(id, db=db)
if not skill:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
skill.user_id != user.id
and not AccessGrants.has_access(
user_id=user.id,
resource_type="skill",
resource_id=skill.id,
permission="write",
db=db,
)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Skills.delete_skill_by_id(id, db=db)
return result