mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
Refactor common code between inlet and outlet
This commit is contained in:
@@ -68,6 +68,10 @@ from open_webui.utils.misc import (
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.tasks import create_task
|
||||
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
|
||||
skip_files = None
|
||||
|
||||
def get_filter_function_ids(model):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
filter_ids.sort(key=get_priority)
|
||||
return filter_ids
|
||||
|
||||
filter_ids = get_filter_function_ids(model)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
if hasattr(function_module, "inlet"):
|
||||
try:
|
||||
inlet = function_module.inlet
|
||||
|
||||
# Create a dictionary of parameters to be passed to the function
|
||||
params = {"body": body} | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
"__model__": model,
|
||||
"__id__": filter_id,
|
||||
}.items()
|
||||
if k in inspect.signature(inlet).parameters
|
||||
}
|
||||
|
||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
body = await inlet(**params)
|
||||
else:
|
||||
body = inlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
raise e
|
||||
|
||||
if skip_files and "files" in body.get("metadata", {}):
|
||||
del body["metadata"]["files"]
|
||||
|
||||
return body, {}
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(
|
||||
request: Request, body: dict, user: UserModel, models, tools
|
||||
) -> tuple[dict, dict]:
|
||||
@@ -782,8 +693,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
form_data, flags = await process_filter_functions(
|
||||
handler_type="inlet",
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
request=request,
|
||||
data=form_data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
|
||||
Reference in New Issue
Block a user