Refactor common code between inlet and outlet

This commit is contained in:
Xingjian Xie
2025-02-06 23:01:43 +00:00
parent 14398ab628
commit 89669a21fc
3 changed files with 143 additions and 203 deletions

View File

@@ -68,6 +68,10 @@ from open_webui.utils.misc import (
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.tasks import create_task
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
skip_files = None
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [
function.id for function in Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
# Apply valves to the function
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if hasattr(function_module, "inlet"):
try:
inlet = function_module.inlet
# Create a dictionary of parameters to be passed to the function
params = {"body": body} | {
k: v
for k, v in {
**extra_params,
"__model__": model,
"__id__": filter_id,
}.items()
if k in inspect.signature(inlet).parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
try:
params["__user__"]["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, params["__user__"]["id"]
)
)
except Exception as e:
print(e)
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
return body, {}
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
@@ -782,8 +693,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
)
try:
form_data, flags = await chat_completion_filter_functions_handler(
request, form_data, model, extra_params
form_data, flags = await process_filter_functions(
handler_type="inlet",
filter_ids=get_sorted_filter_ids(model),
request=request,
data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")