"""Standalone home for the chat ``action`` handler.

Extracted from ``open_webui.utils.chat`` so callers (e.g. ``main.py``) can
import it without pulling in the full chat-completion module.
"""

import inspect
import logging
import sys
from typing import Any

from fastapi import Request

from open_webui.models.functions import Functions
from open_webui.models.users import UserModel
from open_webui.socket.main import get_event_call, get_event_emitter
# NOTE(review): imported at module level here, whereas the previous in-place
# copy deferred this import "to avoid circular dependency (middleware imports
# chat)". That cycle involved chat.py, not this new module — but confirm
# middleware does not (transitively) import this module.
from open_webui.utils.middleware import process_tool_result
from open_webui.utils.models import get_all_models
from open_webui.utils.plugin import get_function_module_from_cache

from open_webui.env import GLOBAL_LOG_LEVEL

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)


async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
    """Execute an installed action function against a chat message.

    Args:
        request: Current FastAPI request; supplies the app-state model cache
            and is forwarded to the action as ``__request__``.
        action_id: Id of the action function. May be ``"<id>.<sub_id>"``;
            only the first dot is split on, so sub-action ids may themselves
            contain dots.
        form_data: Client payload. Must contain ``model``, ``chat_id``,
            ``id`` (message id) and ``session_id``.
        user: The requesting user (``UserModel`` or compatible object with
            an ``id`` attribute).

    Returns:
        Whatever the action returned, replaced by the processed status dict
        from ``process_tool_result`` (so raw ``HTMLResponse``/tuple results
        are never serialized back to the client).

    Raises:
        Exception: If the action or the referenced model cannot be found,
            or if the action itself raised (original exception chained as
            ``__cause__``).
    """
    # Split only on the first dot so dotted sub-action ids survive intact
    # (a plain split(".") would raise ValueError on "a.b.c").
    if "." in action_id:
        action_id, sub_action_id = action_id.split(".", 1)
    else:
        sub_action_id = None

    action = Functions.get_function_by_id(action_id)
    if not action:
        raise Exception(f"Action not found: {action_id}")

    # Lazily warm the model cache on first use.
    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    # Direct (per-request) model connections bypass the shared cache.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]

    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Both socket helpers receive the same routing metadata; pass each its
    # own copy in case either mutates the dict.
    event_metadata = {
        "chat_id": data["chat_id"],
        "message_id": data["id"],
        "session_id": data["session_id"],
        "user_id": user.id,
    }
    __event_emitter__ = get_event_emitter(dict(event_metadata))
    __event_call__ = get_event_call(dict(event_metadata))

    function_module, _, _ = get_function_module_from_cache(request, action_id)

    # Hydrate admin-configured valves before invoking the action.
    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    if hasattr(function_module, "action"):
        try:
            action_handler = function_module.action

            # Inspect the action's signature so we only pass the optional
            # dunder parameters it actually declares.
            sig = inspect.signature(action_handler)
            params = {"body": data}

            # Extra parameters to be passed to the function.
            extra_params = {
                "__model__": model,
                "__id__": sub_action_id if sub_action_id is not None else action_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
                "__request__": request,
            }

            # Add the extra params that are contained in the function signature.
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = user.model_dump() if isinstance(user, UserModel) else {}

                try:
                    if hasattr(function_module, "UserValves"):
                        # Guard against a missing per-user valves record,
                        # mirroring the Valves handling above.
                        user_valves = Functions.get_user_valves_by_id_and_user_id(
                            action_id, user.id
                        )
                        __user__["valves"] = function_module.UserValves(
                            **(user_valves if user_valves else {})
                        )
                except Exception as e:
                    log.exception(f"Failed to get user valves: {e}")

                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(action_handler):
                data = await action_handler(**params)
            else:
                data = action_handler(**params)

            # Process action result for Rich UI embeds (HTMLResponse, tuple
            # with headers).
            processed_result, _, action_embeds = process_tool_result(
                request,
                action_id,
                data,
                "action",
            )

            if action_embeds:
                await __event_emitter__(
                    {
                        "type": "embeds",
                        "data": {
                            "embeds": action_embeds,
                        },
                    }
                )
            # Replace data with the processed status dict so we don't try to
            # serialize the raw HTMLResponse / tuple back to the client.
            data = processed_result

        except Exception as e:
            # Keep the message format callers may rely on, but chain the
            # original exception so the traceback is not lost.
            raise Exception(f"Error: {e}") from e

    return data