feat: chat auto tag

This commit is contained in:
Timothy J. Baek
2024-10-19 20:34:17 -07:00
parent 2db0f58dcb
commit d795940ced
7 changed files with 233 additions and 11 deletions

View File

@@ -108,6 +108,7 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
FUNCTION_CALLING = "function_calling"

View File

@@ -134,6 +134,7 @@ from open_webui.utils.misc import (
)
from open_webui.utils.task import (
moa_response_generation_template,
tags_generation_template,
search_query_generation_template,
title_generation_template,
tools_function_calling_generation_template,
@@ -1545,6 +1546,72 @@ Prompt: {{prompt:middletruncate:8000}}"""
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tags/completions")
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
print("generate_chat_tags")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(model_id)
print(task_model_id)
template = """### Task:
Generate 1-3 broad tags categorizing the main themes of the chat history.
### Guidelines:
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
- Only add more specific subdomains if they are strongly represented throughout the conversation
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
- Use the chat's primary language; default to English if multilingual
- Prioritize accuracy over specificity
### Output:
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
content = tags_generation_template(
template, form_data["messages"], {"name": user.name}
)
print("content", content)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
}
log.debug(payload)
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")

View File

@@ -123,6 +123,24 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
return template
def tags_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def search_query_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str: