diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 00270aabc4..a4d63a6d7c 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -875,6 +875,7 @@ async def chat_completion( "tool_ids": form_data.get("tool_ids", None), "files": form_data.get("files", None), "features": form_data.get("features", None), + "variables": form_data.get("variables", None), } form_data["metadata"] = metadata diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index d9124c29f7..780fc6f50d 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -977,6 +977,7 @@ async def generate_chat_completion( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True + metadata = form_data.pop("metadata", None) try: form_data = GenerateChatCompletionForm(**form_data) except Exception as e: @@ -987,8 +988,6 @@ async def generate_chat_completion( ) payload = {**form_data.model_dump(exclude_none=True)} - if "metadata" in payload: - del payload["metadata"] model_id = payload["model"] model_info = Models.get_model_by_id(model_id) @@ -1006,7 +1005,7 @@ async def generate_chat_completion( payload["options"] = apply_model_params_to_body_ollama( params, payload["options"] ) - payload = apply_model_system_prompt_to_body(params, payload, user) + payload = apply_model_system_prompt_to_body(params, payload, metadata) # Check if user has access to the model if not bypass_filter and user.role == "user": diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 2139be4ef5..c27f35e7e7 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -551,9 +551,9 @@ async def generate_chat_completion( bypass_filter = True idx = 0 + payload = {**form_data} - if "metadata" in payload: - del payload["metadata"] + metadata = payload.pop("metadata", None) model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) @@ -566,7 
+566,7 @@ async def generate_chat_completion( params = model_info.params.model_dump() payload = apply_model_params_to_body_openai(params, payload) - payload = apply_model_system_prompt_to_body(params, payload, user) + payload = apply_model_system_prompt_to_body(params, payload, metadata) # Check if user has access to the model if not bypass_filter and user.role == "user": diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 833db503a5..b821e11a52 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -749,6 +749,8 @@ async def process_chat_payload(request, form_data, metadata, user, model): files.extend(knowledge_files) form_data["files"] = files + variables = form_data.pop("variables", None) + features = form_data.pop("features", None) if features: if "web_search" in features and features["web_search"]: diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 13f98ee019..2eb4622c2b 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -1,4 +1,4 @@ -from open_webui.utils.task import prompt_template +from open_webui.utils.task import prompt_variables_template from open_webui.utils.misc import ( add_or_update_system_message, ) @@ -7,19 +7,17 @@ from typing import Callable, Optional # inplace function: form_data is modified -def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict: +def apply_model_system_prompt_to_body( + params: dict, form_data: dict, metadata: Optional[dict] = None +) -> dict: system = params.get("system", None) if not system: return form_data - if user: - template_params = { - "user_name": user.name, - "user_location": user.info.get("location") if user.info else None, - } - else: - template_params = {} - system = prompt_template(system, **template_params) + if metadata: + variables = 
metadata.get("variables", {}) + system = prompt_variables_template(system, variables) + form_data["messages"] = add_or_update_system_message( system, form_data.get("messages", []) ) @@ -188,4 +187,7 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: if ollama_options: ollama_payload["options"] = ollama_options + if "metadata" in openai_payload: + ollama_payload["metadata"] = openai_payload["metadata"] + return ollama_payload diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index f5ba75ebec..3d8c05d455 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -32,6 +32,12 @@ def get_task_model_id( return task_model_id +def prompt_variables_template(template: str, variables: dict[str, Optional[str]]) -> str: + for variable, value in variables.items(): + template = template.replace(variable, str(value) if value is not None else "") + return template + + def prompt_template( template: str, user_name: Optional[str] = None, user_location: Optional[str] = None ) -> str: diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 24abe61109..3cfb618806 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -45,7 +45,8 @@ promptTemplate, splitStream, sleep, - removeDetailsWithReasoning + removeDetailsWithReasoning, + getPromptVariables } from '$lib/utils'; import { generateChatCompletion } from '$lib/apis/ollama'; @@ -628,7 +629,7 @@ } catch (e) { // Remove the failed doc from the files array files = files.filter((f) => f.name !== url); - toast.error(e); + toast.error(`${e}`); } }; @@ -1558,10 +1559,17 @@ files: (files?.length ?? 0) > 0 ? files : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + features: { image_generation: imageGenerationEnabled, web_search: webSearchEnabled }, + variables: { + ...getPromptVariables( + $user.name, + $settings?.userLocation ? 
await getAndUpdateUserLocation(localStorage.token) : undefined ) }, session_id: $socket?.id, chat_id: $chatId, diff --git a/src/lib/utils/index.ts b/src/lib/utils/index.ts index 20f44f49bb..2ccc1bf5da 100644 --- a/src/lib/utils/index.ts +++ b/src/lib/utils/index.ts @@ -766,6 +766,19 @@ export const blobToFile = (blob, fileName) => { return file; }; +export const getPromptVariables = (user_name, user_location) => { + return { + '{{USER_NAME}}': user_name ?? '', + '{{USER_LOCATION}}': user_location || 'Unknown', + '{{CURRENT_DATETIME}}': getCurrentDateTime(), + '{{CURRENT_DATE}}': getFormattedDate(), + '{{CURRENT_TIME}}': getFormattedTime(), + '{{CURRENT_WEEKDAY}}': getWeekday(), + '{{CURRENT_TIMEZONE}}': getUserTimezone(), + '{{USER_LANGUAGE}}': localStorage.getItem('locale') || 'en-US' + }; +}; + /** * @param {string} template - The template string containing placeholders. * @returns {string} The template string with the placeholders replaced by the prompt.