diff --git a/backend/open_webui/models/skills.py b/backend/open_webui/models/skills.py index da3e70011d..00239e6afd 100644 --- a/backend/open_webui/models/skills.py +++ b/backend/open_webui/models/skills.py @@ -96,6 +96,11 @@ class SkillForm(BaseModel): class SkillListResponse(BaseModel): + items: list[SkillUserResponse] = [] + total: int = 0 + + +class SkillAccessListResponse(BaseModel): items: list[SkillAccessResponse] = [] total: int = 0 @@ -208,81 +213,77 @@ class SkillsTable: def search_skills( self, user_id: str, - filter: dict, + filter: dict = {}, skip: int = 0, limit: int = 30, db: Optional[Session] = None, ) -> SkillListResponse: try: with get_db_context(db) as db: - query = db.query(Skill) + from open_webui.models.users import User, UserModel - query_key = filter.get("query") - if query_key: - query = query.filter( - or_( - Skill.name.ilike(f"%{query_key}%"), - Skill.description.ilike(f"%{query_key}%"), - Skill.id.ilike(f"%{query_key}%"), + # Join with User table for user filtering + query = db.query(Skill, User).outerjoin( + User, User.id == Skill.user_id + ) + + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter( + or_( + Skill.name.ilike(f"%{query_key}%"), + Skill.description.ilike(f"%{query_key}%"), + Skill.id.ilike(f"%{query_key}%"), + User.name.ilike(f"%{query_key}%"), + User.email.ilike(f"%{query_key}%"), + ) ) - ) - # Only active skills - query = query.filter(Skill.is_active == True) + view_option = filter.get("view_option") + if view_option == "created": + query = query.filter(Skill.user_id == user_id) + elif view_option == "shared": + query = query.filter(Skill.user_id != user_id) + + # Apply access grant filtering + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Skill, + filter=filter, + resource_type="skill", + permission="read", + ) query = query.order_by(Skill.updated_at.desc()) - # Apply access control if not admin bypass - if "user_id" in filter: - user_group_ids = { - group.id - for group in Groups.get_groups_by_member_id( - filter["user_id"], db=db - ) - } - all_results = query.all() - accessible = [ - s - for s in all_results - if s.user_id == filter["user_id"] - or AccessGrants.has_access( - user_id=filter["user_id"], - resource_type="skill", - resource_id=s.id, - permission="read", - user_group_ids=user_group_ids, - db=db, - ) - ] - total = len(accessible) - items = accessible[skip : skip + limit] if limit else accessible[skip:] - else: - total = query.count() - if skip: - query = query.offset(skip) - if limit: - query = query.limit(limit) - items = query.all() + # Count BEFORE pagination + total = query.count() - user_ids = list(set(s.user_id for s in items)) - users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] - users_dict = {u.id: u for u in users} + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) - skill_responses = [] - for skill in items: - user = users_dict.get(skill.user_id) - skill_model = self._to_skill_model(skill, db=db) - skill_responses.append( - SkillAccessResponse( - **SkillUserResponse( - **skill_model.model_dump(), - user=user.model_dump() if user else None, - ).model_dump(), - write_access=False, + items = query.all() + + skills = [] + for skill, user in items: + skills.append( + SkillUserResponse( + **self._to_skill_model(skill, db=db).model_dump(), + user=( + UserResponse( + **UserModel.model_validate(user).model_dump() + ) + if user + else None + ), ) ) - return SkillListResponse(items=skill_responses, total=total) + return SkillListResponse(items=skills, total=total) except Exception as e: log.exception(f"Error searching skills: {e}") return SkillListResponse(items=[], total=0) diff --git a/backend/open_webui/routers/skills.py b/backend/open_webui/routers/skills.py index e15226766b..04ba56caee 100644 --- a/backend/open_webui/routers/skills.py +++ b/backend/open_webui/routers/skills.py @@ -14,7 +14,7 @@ from open_webui.models.skills import ( SkillResponse, SkillUserResponse, SkillAccessResponse, - SkillListResponse, + SkillAccessListResponse, Skills, ) from open_webui.models.access_grants import AccessGrants @@ -72,62 +72,56 @@ async def get_skills( ############################ -@router.get("/list", response_model=list[SkillAccessResponse]) +@router.get("/list", response_model=SkillAccessListResponse) async def get_skill_list( - user=Depends(get_verified_user), db: Session = Depends(get_session) -): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - skills = Skills.get_skills(db=db) - else: - skills = Skills.get_skills_by_user_id(user.id, "read", db=db) - - return [ - SkillAccessResponse( - **skill.model_dump(), - write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or user.id == skill.user_id - or AccessGrants.has_access( - user_id=user.id, - resource_type="skill", - resource_id=skill.id, - permission="write", - db=db, - ) - ), - ) - for skill in skills - ] - - -############################ -# SearchSkills -############################ - - -@router.get("/search", response_model=SkillListResponse) -async def search_skills( query: Optional[str] = None, + view_option: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - page = max(page, 1) limit = PAGE_ITEM_COUNT + + page = max(1, page) skip = (page - 1) * limit filter = {} if query: filter["query"] = query + if view_option: + filter["view_option"] = view_option if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL): + groups = Groups.get_groups_by_member_id(user.id, db=db) + if groups: + filter["group_ids"] = [group.id for group in groups] + filter["user_id"] = user.id result = Skills.search_skills( user.id, filter=filter, skip=skip, limit=limit, db=db ) - return result + return SkillAccessListResponse( + items=[ + SkillAccessResponse( + **skill.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == skill.user_id + or AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="write", + db=db, + ) + ), + ) + for skill in result.items + ], + total=result.total, + ) ############################ diff --git a/src/lib/apis/skills/index.ts b/src/lib/apis/skills/index.ts index 8477b4a6f5..14543046b9 100644 --- a/src/lib/apis/skills/index.ts +++ b/src/lib/apis/skills/index.ts @@ -93,18 +93,20 @@ export const getSkillList = async (token: string = '') => { return res; }; -export const searchSkills = async ( +export const getSkillItems = async ( token: string = '', query: string | null = null, + viewOption: string | null = null, page: number | null = null ) => { let error = null; const searchParams = new URLSearchParams(); if (query) searchParams.append('query', query); + if (viewOption) searchParams.append('view_option', viewOption); if (page) searchParams.append('page', page.toString()); - const res = await fetch(`${WEBUI_API_BASE_URL}/skills/search?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/list?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', @@ -120,7 +122,7 @@ export const searchSkills = async ( return json; }) .catch((err) => { - error = err.detail; + error = err; console.error(err); return null; }); diff --git a/src/lib/components/chat/MessageInput/Commands/Skills.svelte b/src/lib/components/chat/MessageInput/Commands/Skills.svelte index 9454b950a0..b998cc00ce 100644 --- a/src/lib/components/chat/MessageInput/Commands/Skills.svelte +++ b/src/lib/components/chat/MessageInput/Commands/Skills.svelte @@ -1,6 +1,6 @@