mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 03:47:49 +01:00
feat: search query threshold
This commit is contained in:
@@ -618,6 +618,11 @@ ADMIN_EMAIL = PersistentConfig(
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# TASKS
|
||||
####################################
|
||||
|
||||
|
||||
TASK_MODEL = PersistentConfig(
|
||||
"TASK_MODEL",
|
||||
"task.model.default",
|
||||
@@ -664,6 +669,15 @@ Question:
|
||||
)
|
||||
|
||||
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
|
||||
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
|
||||
"task.search.prompt_length_threshold",
|
||||
os.environ.get(
|
||||
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
|
||||
100,
|
||||
),
|
||||
)
|
||||
|
||||
####################################
|
||||
# WEBUI_SECRET_KEY
|
||||
####################################
|
||||
|
||||
@@ -81,6 +81,7 @@ from config import (
|
||||
TASK_MODEL_EXTERNAL,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
||||
AppConfig,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
@@ -144,6 +145,9 @@ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMP
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
||||
)
|
||||
|
||||
app.state.MODELS = {}
|
||||
|
||||
@@ -596,6 +600,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_search_query")
|
||||
|
||||
if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
|
||||
)
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
|
||||
Reference in New Issue
Block a user