mirror of
https://github.com/open-webui/open-webui.git
synced 2026-02-24 04:00:31 +01:00
refac
This commit is contained in:
@@ -509,8 +509,8 @@ from open_webui.utils.models import (
|
||||
from open_webui.utils.chat import (
|
||||
generate_chat_completion as chat_completion_handler,
|
||||
chat_completed as chat_completed_handler,
|
||||
chat_action as chat_action_handler,
|
||||
)
|
||||
from open_webui.utils.actions import chat_action as chat_action_handler
|
||||
from open_webui.utils.embeddings import generate_embeddings
|
||||
from open_webui.utils.middleware import (
|
||||
build_chat_response_context,
|
||||
|
||||
139
backend/open_webui/utils/actions.py
Normal file
139
backend/open_webui/utils/actions.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import logging
|
||||
import sys
|
||||
import inspect
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.functions import Functions
|
||||
|
||||
from open_webui.socket.main import get_event_call, get_event_emitter
|
||||
from open_webui.utils.plugin import get_function_module_from_cache
|
||||
from open_webui.utils.models import get_all_models
|
||||
from open_webui.utils.middleware import process_tool_result
|
||||
|
||||
from open_webui.env import GLOBAL_LOG_LEVEL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
||||
if "." in action_id:
|
||||
action_id, sub_action_id = action_id.split(".")
|
||||
else:
|
||||
sub_action_id = None
|
||||
|
||||
action = Functions.get_function_by_id(action_id)
|
||||
if not action:
|
||||
raise Exception(f"Action not found: {action_id}")
|
||||
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
|
||||
if model_id not in models:
|
||||
raise Exception("Model not found")
|
||||
model = models[model_id]
|
||||
|
||||
__event_emitter__ = get_event_emitter(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
__event_call__ = get_event_call(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
|
||||
function_module, _, _ = get_function_module_from_cache(request, action_id)
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(action_id)
|
||||
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||
|
||||
if hasattr(function_module, "action"):
|
||||
try:
|
||||
action = function_module.action
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(action)
|
||||
params = {"body": data}
|
||||
|
||||
# Extra parameters to be passed to the function
|
||||
extra_params = {
|
||||
"__model__": model,
|
||||
"__id__": sub_action_id if sub_action_id is not None else action_id,
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__request__": request,
|
||||
}
|
||||
|
||||
# Add extra params in contained in function signature
|
||||
for key, value in extra_params.items():
|
||||
if key in sig.parameters:
|
||||
params[key] = value
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = user.model_dump() if isinstance(user, UserModel) else {}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
action_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to get user values: {e}")
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if inspect.iscoroutinefunction(action):
|
||||
data = await action(**params)
|
||||
else:
|
||||
data = action(**params)
|
||||
|
||||
# Process action result for Rich UI embeds (HTMLResponse, tuple with headers)
|
||||
processed_result, _, action_embeds = process_tool_result(
|
||||
request,
|
||||
action_id,
|
||||
data,
|
||||
"action",
|
||||
)
|
||||
|
||||
if action_embeds:
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "embeds",
|
||||
"data": {
|
||||
"embeds": action_embeds,
|
||||
},
|
||||
}
|
||||
)
|
||||
# Replace data with the processed status dict so we don't
|
||||
# try to serialize the raw HTMLResponse / tuple back to the client
|
||||
data = processed_result
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
|
||||
return data
|
||||
@@ -6,7 +6,7 @@ from aiocache import cached
|
||||
from typing import Any, Optional
|
||||
import random
|
||||
import json
|
||||
import inspect
|
||||
|
||||
import uuid
|
||||
import asyncio
|
||||
|
||||
@@ -356,124 +356,3 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
|
||||
|
||||
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
||||
if "." in action_id:
|
||||
action_id, sub_action_id = action_id.split(".")
|
||||
else:
|
||||
sub_action_id = None
|
||||
|
||||
action = Functions.get_function_by_id(action_id)
|
||||
if not action:
|
||||
raise Exception(f"Action not found: {action_id}")
|
||||
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
|
||||
if model_id not in models:
|
||||
raise Exception("Model not found")
|
||||
model = models[model_id]
|
||||
|
||||
__event_emitter__ = get_event_emitter(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
__event_call__ = get_event_call(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
|
||||
function_module, _, _ = get_function_module_from_cache(request, action_id)
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(action_id)
|
||||
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||
|
||||
if hasattr(function_module, "action"):
|
||||
try:
|
||||
action = function_module.action
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(action)
|
||||
params = {"body": data}
|
||||
|
||||
# Extra parameters to be passed to the function
|
||||
extra_params = {
|
||||
"__model__": model,
|
||||
"__id__": sub_action_id if sub_action_id is not None else action_id,
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__request__": request,
|
||||
}
|
||||
|
||||
# Add extra params in contained in function signature
|
||||
for key, value in extra_params.items():
|
||||
if key in sig.parameters:
|
||||
params[key] = value
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = user.model_dump() if isinstance(user, UserModel) else {}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
action_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to get user values: {e}")
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if inspect.iscoroutinefunction(action):
|
||||
data = await action(**params)
|
||||
else:
|
||||
data = action(**params)
|
||||
|
||||
# Process action result for Rich UI embeds (HTMLResponse, tuple with headers)
|
||||
# Deferred import to avoid circular dependency (middleware imports chat)
|
||||
from open_webui.utils.middleware import process_tool_result
|
||||
|
||||
processed_result, _, action_embeds = process_tool_result(
|
||||
request,
|
||||
action_id,
|
||||
data,
|
||||
"action",
|
||||
)
|
||||
|
||||
if action_embeds:
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "embeds",
|
||||
"data": {
|
||||
"embeds": action_embeds,
|
||||
},
|
||||
}
|
||||
)
|
||||
# Replace data with the processed status dict so we don't
|
||||
# try to serialize the raw HTMLResponse / tuple back to the client
|
||||
data = processed_result
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user