mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
feat: chat auto tag
This commit is contained in:
@@ -108,6 +108,7 @@ class TASKS(str, Enum):
|
||||
|
||||
DEFAULT = lambda task="": f"{task if task else 'generation'}"
|
||||
TITLE_GENERATION = "title_generation"
|
||||
TAGS_GENERATION = "tags_generation"
|
||||
EMOJI_GENERATION = "emoji_generation"
|
||||
QUERY_GENERATION = "query_generation"
|
||||
FUNCTION_CALLING = "function_calling"
|
||||
|
||||
@@ -134,6 +134,7 @@ from open_webui.utils.misc import (
|
||||
)
|
||||
from open_webui.utils.task import (
|
||||
moa_response_generation_template,
|
||||
tags_generation_template,
|
||||
search_query_generation_template,
|
||||
title_generation_template,
|
||||
tools_function_calling_generation_template,
|
||||
@@ -1545,6 +1546,72 @@ Prompt: {{prompt:middletruncate:8000}}"""
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/tags/completions")
|
||||
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_chat_tags")
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
print(task_model_id)
|
||||
|
||||
template = """### Task:
|
||||
Generate 1-3 broad tags categorizing the main themes of the chat history.
|
||||
|
||||
### Guidelines:
|
||||
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
|
||||
- Only add more specific subdomains if they are strongly represented throughout the conversation
|
||||
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
|
||||
- Use the chat's primary language; default to English if multilingual
|
||||
- Prioritize accuracy over specificity
|
||||
|
||||
### Output:
|
||||
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
||||
|
||||
### Chat History:
|
||||
<chat_history>
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
content = tags_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
|
||||
print("content", content)
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
|
||||
}
|
||||
log.debug(payload)
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/query/completions")
|
||||
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_search_query")
|
||||
|
||||
@@ -123,6 +123,24 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
|
||||
return template
|
||||
|
||||
|
||||
def tags_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
def search_query_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
Reference in New Issue
Block a user