feat: autocomplete backend endpoint

This commit is contained in:
Timothy Jaeryang Baek
2024-11-28 23:53:52 -08:00
parent 28ce102a79
commit 0e8e9820d0
4 changed files with 139 additions and 0 deletions

View File

@@ -89,6 +89,8 @@ from open_webui.config import (
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
@@ -127,6 +129,7 @@ from open_webui.utils.task import (
rag_template,
title_generation_template,
query_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
moa_response_generation_template,
@@ -215,6 +218,10 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
@@ -1982,6 +1989,73 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/auto/completions")
async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
context = form_data.get("context")
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
app.state.config.TASK_MODEL,
app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating autocompletion using model {task_model_id} for user {user.email}"
)
if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
content = autocomplete_generation_template(
template, form_data["messages"], context, {"name": user.name}
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
"task": str(TASKS.AUTOCOMPLETION_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/emoji/completions")
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):