enh: retrieval query generation

This commit is contained in:
Timothy Jaeryang Baek
2024-11-19 02:24:32 -08:00
parent 09c6e4b92f
commit dbb67a12ca
7 changed files with 217 additions and 138 deletions

View File

@@ -78,11 +78,13 @@ from open_webui.config import (
ENV,
FRONTEND_BUILD_DIR,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
ENABLE_SEARCH_QUERY_GENERATION,
ENABLE_RETRIEVAL_QUERY_GENERATION,
QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import (
moa_response_generation_template,
tags_generation_template,
search_query_generation_template,
query_generation_template,
emoji_generation_template,
title_generation_template,
tools_function_calling_generation_template,
@@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -492,14 +493,41 @@ async def chat_completion_tools_handler(
return body, {"contexts": contexts, "citations": citations}
async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
async def chat_completion_files_handler(
body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
contexts = []
citations = []
try:
queries_response = await generate_queries(
{
"model": body["model"],
"messages": body["messages"],
"type": "retrieval",
},
user,
)
queries_response = queries_response["choices"][0]["message"]["content"]
try:
queries_response = json.loads(queries_response)
except Exception as e:
queries_response = {"queries": []}
queries = queries_response.get("queries", [])
except Exception as e:
queries = []
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
print(f"{queries=}")
if files := body.get("metadata", {}).get("files", None):
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
queries=queries,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
k=retrieval_app.state.config.TOP_K,
reranking_function=retrieval_app.state.sentence_transformer_rf,
@@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
log.exception(e)
try:
body, flags = await chat_completion_files_handler(body)
body, flags = await chat_completion_files_handler(body, user)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
@@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)):
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel):
TITLE_GENERATION_PROMPT_TEMPLATE: str
TAGS_GENERATION_PROMPT_TEMPLATE: str
ENABLE_TAGS_GENERATION: bool
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
ENABLE_SEARCH_QUERY: bool
ENABLE_SEARCH_QUERY_GENERATION: bool
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
QUERY_GENERATION_PROMPT_TEMPLATE: str
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
)
app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
form_data.ENABLE_SEARCH_QUERY_GENERATION
)
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
)
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
@@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")
if not app.state.config.ENABLE_SEARCH_QUERY:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Search query generation is disabled",
)
@app.post("/api/task/queries/completions")
async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
print("generate_queries")
type = form_data.get("type")
if type == "web_search":
if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Search query generation is disabled",
)
elif type == "retrieval":
if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Query generation is disabled",
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
@@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
model = models[task_model_id]
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
else:
template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}.
template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
User Message:
{{prompt:end:4000}}
Interaction History:
{{MESSAGES:END:6}}
Search Query:"""
content = search_query_generation_template(
content = query_generation_template(
template, form_data["messages"], {"name": user.name}
)
@@ -1851,13 +1887,6 @@ Search Query:"""
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
**(
{"max_tokens": 30}
if models[task_model_id]["owned_by"] == "ollama"
else {
"max_completion_tokens": 30,
}
),
"metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data},
}
log.debug(payload)