enh: image prompt enhancer

This commit is contained in:
Timothy Jaeryang Baek
2025-01-16 00:06:37 -08:00
parent d3a5b9c127
commit 0360aa5520
7 changed files with 166 additions and 1 deletions

View File

@@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
query_generation_template,
image_prompt_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
@@ -23,6 +24,7 @@ from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
@@ -50,6 +52,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@@ -65,6 +68,7 @@ class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
@@ -114,6 +118,7 @@ async def update_task_config(
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@@ -256,6 +261,66 @@ async def generate_chat_tags(
)
@router.post("/image_prompt/completions")
async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating image prompt using model {task_model_id} for user {user.email} "
)
if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
content = image_prompt_generation_template(
template,
form_data["messages"],
user={
"name": user.name,
},
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error("Exception occurred", exc_info=True)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": "An internal error has occurred."},
)
@router.post("/queries/completions")
async def generate_queries(
request: Request, form_data: dict, user=Depends(get_verified_user)