mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 03:47:49 +01:00
enh: retrieval query generation
This commit is contained in:
@@ -78,11 +78,13 @@ from open_webui.config import (
|
||||
ENV,
|
||||
FRONTEND_BUILD_DIR,
|
||||
OAUTH_PROVIDERS,
|
||||
ENABLE_SEARCH_QUERY,
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
STATIC_DIR,
|
||||
TASK_MODEL,
|
||||
TASK_MODEL_EXTERNAL,
|
||||
ENABLE_SEARCH_QUERY_GENERATION,
|
||||
ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
@@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
from open_webui.utils.task import (
|
||||
moa_response_generation_template,
|
||||
tags_generation_template,
|
||||
search_query_generation_template,
|
||||
query_generation_template,
|
||||
emoji_generation_template,
|
||||
title_generation_template,
|
||||
tools_function_calling_generation_template,
|
||||
@@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
||||
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
|
||||
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
|
||||
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
@@ -492,14 +493,41 @@ async def chat_completion_tools_handler(
|
||||
return body, {"contexts": contexts, "citations": citations}
|
||||
|
||||
|
||||
async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
|
||||
async def chat_completion_files_handler(
|
||||
body: dict, user: UserModel
|
||||
) -> tuple[dict, dict[str, list]]:
|
||||
contexts = []
|
||||
citations = []
|
||||
|
||||
try:
|
||||
queries_response = await generate_queries(
|
||||
{
|
||||
"model": body["model"],
|
||||
"messages": body["messages"],
|
||||
"type": "retrieval",
|
||||
},
|
||||
user,
|
||||
)
|
||||
queries_response = queries_response["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
queries_response = json.loads(queries_response)
|
||||
except Exception as e:
|
||||
queries_response = {"queries": []}
|
||||
|
||||
queries = queries_response.get("queries", [])
|
||||
except Exception as e:
|
||||
queries = []
|
||||
|
||||
if len(queries) == 0:
|
||||
queries = [get_last_user_message(body["messages"])]
|
||||
|
||||
print(f"{queries=}")
|
||||
|
||||
if files := body.get("metadata", {}).get("files", None):
|
||||
contexts, citations = get_rag_context(
|
||||
files=files,
|
||||
messages=body["messages"],
|
||||
queries=queries,
|
||||
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
|
||||
k=retrieval_app.state.config.TOP_K,
|
||||
reranking_function=retrieval_app.state.sentence_transformer_rf,
|
||||
@@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
log.exception(e)
|
||||
|
||||
try:
|
||||
body, flags = await chat_completion_files_handler(body)
|
||||
body, flags = await chat_completion_files_handler(body, user)
|
||||
contexts.extend(flags.get("contexts", []))
|
||||
citations.extend(flags.get("citations", []))
|
||||
except Exception as e:
|
||||
@@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)):
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
@@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel):
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_TAGS_GENERATION: bool
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_SEARCH_QUERY: bool
|
||||
ENABLE_SEARCH_QUERY_GENERATION: bool
|
||||
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
|
||||
QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
||||
|
||||
|
||||
@@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
||||
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
|
||||
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
|
||||
form_data.ENABLE_SEARCH_QUERY_GENERATION
|
||||
)
|
||||
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
|
||||
form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
|
||||
)
|
||||
|
||||
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
@@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
@@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/query/completions")
|
||||
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_search_query")
|
||||
if not app.state.config.ENABLE_SEARCH_QUERY:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Search query generation is disabled",
|
||||
)
|
||||
@app.post("/api/task/queries/completions")
|
||||
async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_queries")
|
||||
type = form_data.get("type")
|
||||
if type == "web_search":
|
||||
if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Search query generation is disabled",
|
||||
)
|
||||
elif type == "retrieval":
|
||||
if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Query generation is disabled",
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
@@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
|
||||
model = models[task_model_id]
|
||||
|
||||
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}.
|
||||
template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
User Message:
|
||||
{{prompt:end:4000}}
|
||||
|
||||
Interaction History:
|
||||
{{MESSAGES:END:6}}
|
||||
|
||||
Search Query:"""
|
||||
|
||||
content = search_query_generation_template(
|
||||
content = query_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
|
||||
@@ -1851,13 +1887,6 @@ Search Query:"""
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 30}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 30,
|
||||
}
|
||||
),
|
||||
"metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data},
|
||||
}
|
||||
log.debug(payload)
|
||||
|
||||
Reference in New Issue
Block a user