mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
wip: prompts
This commit is contained in:
@@ -69,7 +69,7 @@ class PromptForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class PromptsTable:
|
class PromptsTable:
|
||||||
def insert_new_prompt(
|
async def insert_new_prompt(
|
||||||
self, user_id: str, form_data: PromptForm
|
self, user_id: str, form_data: PromptForm
|
||||||
) -> Optional[PromptModel]:
|
) -> Optional[PromptModel]:
|
||||||
prompt = PromptModel(
|
prompt = PromptModel(
|
||||||
@@ -83,9 +83,9 @@ class PromptsTable:
|
|||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
result = Prompt(**prompt.model_dump())
|
result = Prompt(**prompt.model_dump())
|
||||||
db.add(result)
|
await db.add(result)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(result)
|
await db.refresh(result)
|
||||||
if result:
|
if result:
|
||||||
return PromptModel.model_validate(result)
|
return PromptModel.model_validate(result)
|
||||||
else:
|
else:
|
||||||
@@ -93,10 +93,10 @@ class PromptsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
async def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
prompt = await db.query(Prompt).filter_by(command=command).first()
|
||||||
return PromptModel.model_validate(prompt)
|
return PromptModel.model_validate(prompt)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -105,7 +105,9 @@ class PromptsTable:
|
|||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
prompts = []
|
prompts = []
|
||||||
|
|
||||||
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
|
for prompt in (
|
||||||
|
await db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
|
||||||
|
):
|
||||||
user = await Users.get_user_by_id(prompt.user_id)
|
user = await Users.get_user_by_id(prompt.user_id)
|
||||||
prompts.append(
|
prompts.append(
|
||||||
PromptUserResponse.model_validate(
|
PromptUserResponse.model_validate(
|
||||||
@@ -118,10 +120,10 @@ class PromptsTable:
|
|||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
def get_prompts_by_user_id(
|
async def get_prompts_by_user_id(
|
||||||
self, user_id: str, permission: str = "write"
|
self, user_id: str, permission: str = "write"
|
||||||
) -> list[PromptUserResponse]:
|
) -> list[PromptUserResponse]:
|
||||||
prompts = self.get_prompts()
|
prompts = await self.get_prompts()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
prompt
|
prompt
|
||||||
@@ -130,26 +132,26 @@ class PromptsTable:
|
|||||||
or await has_access(user_id, permission, prompt.access_control)
|
or await has_access(user_id, permission, prompt.access_control)
|
||||||
]
|
]
|
||||||
|
|
||||||
def update_prompt_by_command(
|
async def update_prompt_by_command(
|
||||||
self, command: str, form_data: PromptForm
|
self, command: str, form_data: PromptForm
|
||||||
) -> Optional[PromptModel]:
|
) -> Optional[PromptModel]:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
prompt = await db.query(Prompt).filter_by(command=command).first()
|
||||||
prompt.title = form_data.title
|
prompt.title = form_data.title
|
||||||
prompt.content = form_data.content
|
prompt.content = form_data.content
|
||||||
prompt.access_control = form_data.access_control
|
prompt.access_control = form_data.access_control
|
||||||
prompt.timestamp = int(time.time())
|
prompt.timestamp = int(time.time())
|
||||||
db.commit()
|
await db.commit()
|
||||||
return PromptModel.model_validate(prompt)
|
return PromptModel.model_validate(prompt)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_prompt_by_command(self, command: str) -> bool:
|
async def delete_prompt_by_command(self, command: str) -> bool:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
db.query(Prompt).filter_by(command=command).delete()
|
await db.query(Prompt).filter_by(command=command).delete()
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ router = APIRouter()
|
|||||||
@router.get("/", response_model=list[PromptModel])
|
@router.get("/", response_model=list[PromptModel])
|
||||||
async def get_prompts(user=Depends(get_verified_user)):
|
async def get_prompts(user=Depends(get_verified_user)):
|
||||||
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
||||||
prompts = Prompts.get_prompts()
|
prompts = await Prompts.get_prompts()
|
||||||
else:
|
else:
|
||||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
|
prompts = await Prompts.get_prompts_by_user_id(user.id, "read")
|
||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
@@ -32,9 +32,9 @@ async def get_prompts(user=Depends(get_verified_user)):
|
|||||||
@router.get("/list", response_model=list[PromptUserResponse])
|
@router.get("/list", response_model=list[PromptUserResponse])
|
||||||
async def get_prompt_list(user=Depends(get_verified_user)):
|
async def get_prompt_list(user=Depends(get_verified_user)):
|
||||||
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
||||||
prompts = Prompts.get_prompts()
|
prompts = await Prompts.get_prompts()
|
||||||
else:
|
else:
|
||||||
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
|
prompts = await Prompts.get_prompts_by_user_id(user.id, "write")
|
||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
@@ -56,9 +56,9 @@ async def create_new_prompt(
|
|||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
prompt = await Prompts.get_prompt_by_command(form_data.command)
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
prompt = await Prompts.insert_new_prompt(user.id, form_data)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
return prompt
|
return prompt
|
||||||
@@ -141,7 +141,7 @@ async def update_prompt_by_command(
|
|||||||
|
|
||||||
@router.delete("/command/{command}/delete", response_model=bool)
|
@router.delete("/command/{command}/delete", response_model=bool)
|
||||||
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
prompt = await Prompts.get_prompt_by_command(f"/{command}")
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -158,5 +158,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
|||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
result = await Prompts.delete_prompt_by_command(f"/{command}")
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ def get_permissions(
|
|||||||
return permissions
|
return permissions
|
||||||
|
|
||||||
|
|
||||||
def has_permission(
|
async def has_permission(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
permission_key: str,
|
permission_key: str,
|
||||||
default_permissions: Dict[str, Any] = {},
|
default_permissions: Dict[str, Any] = {},
|
||||||
@@ -93,7 +93,7 @@ def has_permission(
|
|||||||
permission_hierarchy = permission_key.split(".")
|
permission_hierarchy = permission_key.split(".")
|
||||||
|
|
||||||
# Retrieve user group permissions
|
# Retrieve user group permissions
|
||||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
user_groups = await Groups.get_groups_by_member_id(user_id)
|
||||||
|
|
||||||
for group in user_groups:
|
for group in user_groups:
|
||||||
group_permissions = group.permissions
|
group_permissions = group.permissions
|
||||||
|
|||||||
Reference in New Issue
Block a user