diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 381d4109fd..4a85ba9029 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -10,7 +10,8 @@ from open_webui.models.prompt_history import PromptHistories from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, func, cast + from open_webui.utils.access_control import has_access @@ -85,7 +86,18 @@ class PromptAccessResponse(PromptUserResponse): write_access: Optional[bool] = False +class PromptListResponse(BaseModel): + items: list[PromptUserResponse] + total: int + + +class PromptAccessListResponse(BaseModel): + items: list[PromptAccessResponse] + total: int + + class PromptForm(BaseModel): + command: str name: str # Changed from title content: str @@ -227,7 +239,109 @@ class PromptsTable: or has_access(user_id, permission, prompt.access_control, user_group_ids) ] + def search_prompts( + self, + user_id: str, + filter: dict = {}, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, + ) -> PromptListResponse: + with get_db_context(db) as db: + from open_webui.models.users import User, UserModel + + # Join with User table for user filtering and sorting + query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id) + query = query.filter(Prompt.is_active == True) + + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter( + or_( + Prompt.name.ilike(f"%{query_key}%"), + Prompt.command.ilike(f"%{query_key}%"), + Prompt.content.ilike(f"%{query_key}%"), + User.name.ilike(f"%{query_key}%"), + User.email.ilike(f"%{query_key}%"), + ) + ) + + view_option = filter.get("view_option") + if view_option == "created": + query = query.filter(Prompt.user_id == user_id) + elif view_option == "shared": + query = query.filter(Prompt.user_id != user_id) + + # Apply access control filtering + group_ids = filter.get("group_ids", []) + filter_user_id = filter.get("user_id") + + if filter_user_id: + # User must have access: owner OR public OR explicit access + access_conditions = [ + Prompt.user_id == filter_user_id, # Owner + Prompt.access_control == None, # Public + ] + query = query.filter(or_(*access_conditions)) + + tag = filter.get("tag") + if tag: + # Search for tag in JSON array field + like_pattern = f'%"{tag.lower()}"%' + tags_text = func.lower(cast(Prompt.tags, String)) + query = query.filter(tags_text.like(like_pattern)) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by == "name": + if direction == "asc": + query = query.order_by(Prompt.name.asc()) + else: + query = query.order_by(Prompt.name.desc()) + elif order_by == "created_at": + if direction == "asc": + query = query.order_by(Prompt.created_at.asc()) + else: + query = query.order_by(Prompt.created_at.desc()) + elif order_by == "updated_at": + if direction == "asc": + query = query.order_by(Prompt.updated_at.asc()) + else: + query = query.order_by(Prompt.updated_at.desc()) + else: + query = query.order_by(Prompt.updated_at.desc()) + else: + query = query.order_by(Prompt.updated_at.desc()) + + # Count BEFORE pagination + total = query.count() + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + items = query.all() + + prompts = [] + for prompt, user in items: + prompts.append( + PromptUserResponse( + **PromptModel.model_validate(prompt).model_dump(), + user=( + UserResponse(**UserModel.model_validate(user).model_dump()) + if user + else None + ), + ) + ) + + return PromptListResponse(items=prompts, total=total) + def update_prompt_by_command( + self, command: str, form_data: PromptForm, diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 3515009175..fc24ccaf4b 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -5,9 +5,11 @@ from open_webui.models.prompts import ( PromptForm, PromptUserResponse, PromptAccessResponse, + PromptAccessListResponse, PromptModel, Prompts, ) +from open_webui.models.groups import Groups from open_webui.models.prompt_history import ( PromptHistories, PromptHistoryModel, @@ -34,6 +36,8 @@ class PromptMetadataForm(BaseModel): router = APIRouter() +PAGE_ITEM_COUNT = 30 + ############################ # GetPrompts @@ -67,26 +71,57 @@ async def get_prompt_tags( return sorted(list(tags)) -@router.get("/list", response_model=list[PromptAccessResponse]) +@router.get("/list", response_model=PromptAccessListResponse) async def get_prompt_list( - user=Depends(get_verified_user), db: Session = Depends(get_session) + query: Optional[str] = None, + view_option: Optional[str] = None, + tag: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - prompts = Prompts.get_prompts(db=db) - else: - prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db) + limit = PAGE_ITEM_COUNT - return [ - PromptAccessResponse( - **prompt.model_dump(), - write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) - ), - ) - for prompt in prompts - ] + page = max(1, page) + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if view_option: + filter["view_option"] = view_option + if tag: + filter["tag"] = tag + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + + if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL): + groups = Groups.get_groups_by_member_id(user.id, db=db) + if groups: + filter["group_ids"] = [group.id for group in groups] + + filter["user_id"] = user.id + + result = Prompts.search_prompts(user.id, filter=filter, skip=skip, limit=limit, db=db) + + return PromptAccessListResponse( + items=[ + PromptAccessResponse( + **prompt.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == prompt.user_id + or has_access(user.id, "write", prompt.access_control, db=db) + ), + ) + for prompt in result.items + ], + total=result.total, + ) ############################ diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts index 58ab889f80..e9cd6e8481 100644 --- a/src/lib/apis/prompts/index.ts +++ b/src/lib/apis/prompts/index.ts @@ -136,7 +136,67 @@ export const getPromptTags = async (token: string = '') => { return res; }; +export const getPromptItems = async ( + token: string = '', + query: string | null, + viewOption: string | null, + selectedTag: string | null, + orderBy: string | null, + direction: string | null, + page: number +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (query) { + searchParams.append('query', query); + } + if (viewOption) { + searchParams.append('view_option', viewOption); + } + if (selectedTag) { + searchParams.append('tag', selectedTag); + } + if (orderBy) { + searchParams.append('order_by', orderBy); + } + if (direction) { + searchParams.append('direction', direction); + } + if (page) { + searchParams.append('page', page.toString()); + } + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getPromptList = async (token: string = '') => { + let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list`, { diff --git a/src/lib/components/workspace/Prompts.svelte b/src/lib/components/workspace/Prompts.svelte index 3d9a9acdb4..3c2e333b95 100644 --- a/src/lib/components/workspace/Prompts.svelte +++ b/src/lib/components/workspace/Prompts.svelte @@ -11,7 +11,7 @@ createNewPrompt, deletePromptById, getPrompts, - getPromptList, + getPromptItems, getPromptTags } from '$lib/apis/prompts'; import { capitalizeFirstLetter, slugify, copyToClipboard } from '$lib/utils'; @@ -31,6 +31,7 @@ import ViewSelector from './common/ViewSelector.svelte'; import TagSelector from './common/TagSelector.svelte'; import Badge from '$lib/components/common/Badge.svelte'; + import Pagination from '../common/Pagination.svelte'; let shiftKey = false; const i18n = getContext('i18n'); @@ -40,8 +41,10 @@ let importFiles = null; let query = ''; - let prompts = []; + let prompts = null; let tags = []; + let total = null; + let loading = false; let showDeleteConfirm = false; let deletePrompt = null; @@ -51,27 +54,54 @@ let selectedTag = ''; let copiedId: string | null = null; - let filteredItems = []; + let page = 1; + let searchDebounceTimer; - $: if (prompts && query !== undefined && viewOption !== undefined && selectedTag !== undefined) { - setFilteredItems(); + // Debounce only query changes + $: if (query !== undefined) { + loading = true; + clearTimeout(searchDebounceTimer); + searchDebounceTimer = setTimeout(() => { + getPromptList(); + }, 300); } - const setFilteredItems = () => { - filteredItems = prompts.filter((p) => { - if (query === '' && viewOption === '' && selectedTag === '') return true; - const lowerQuery = query.toLowerCase(); - return ( - ((p.title || '').toLowerCase().includes(lowerQuery) || - (p.command || '').toLowerCase().includes(lowerQuery) || - (p.user?.name || '').toLowerCase().includes(lowerQuery) || - (p.user?.email || '').toLowerCase().includes(lowerQuery)) && - (viewOption === '' || - (viewOption === 'created' && p.user_id === $user?.id) || - (viewOption === 'shared' && p.user_id !== $user?.id)) && - (selectedTag === '' || (p.tags && p.tags.includes(selectedTag))) - ); - }); + // Immediate response to page/filter changes + $: if (page && selectedTag !== undefined && viewOption !== undefined) { + getPromptList(); + } + + const getPromptList = async () => { + loading = true; + try { + const res = await getPromptItems( + localStorage.token, + query, + viewOption, + selectedTag, + null, + null, + page + ).catch((error) => { + toast.error(`${error}`); + return null; + }); + + if (res) { + prompts = res.items; + total = res.total; + + // get tags + tags = await getPromptTags(localStorage.token).catch((error) => { + toast.error(`${error}`); + return []; + }); + } + } catch (err) { + console.error(err); + } finally { + loading = false; + } }; const shareHandler = async (prompt) => { @@ -134,18 +164,14 @@ toast.success($i18n.t(`Deleted {{name}}`, { name: command })); } - await init(); - }; - - const init = async () => { - prompts = await getPromptList(localStorage.token); - tags = await getPromptTags(localStorage.token); + page = 1; + getPromptList(); await _prompts.set(await getPrompts(localStorage.token)); }; onMount(async () => { viewOption = localStorage?.workspaceViewOption || ''; - await init(); + page = 1; loaded = true; const onKeyDown = (event) => { @@ -222,7 +248,9 @@ }); } - prompts = await getPromptList(localStorage.token); + prompts = null; + page = 1; + getPromptList(); await _prompts.set(await getPrompts(localStorage.token)); importFiles = []; @@ -239,7 +267,7 @@