mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
Merge remote-tracking branch 'upstream/dev' into feat/oauth
This commit is contained in:
370
backend/main.py
370
backend/main.py
@@ -62,9 +62,7 @@ from apps.webui.models.functions import Functions
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
from utils.misc import parse_duration
|
||||
from utils.utils import (
|
||||
get_admin_user,
|
||||
get_verified_user,
|
||||
@@ -82,6 +80,7 @@ from utils.misc import (
|
||||
get_last_user_message,
|
||||
add_or_update_system_message,
|
||||
stream_message_template,
|
||||
parse_duration,
|
||||
)
|
||||
|
||||
from apps.rag.utils import get_rag_context, rag_template
|
||||
@@ -113,6 +112,7 @@ from config import (
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
SAFE_MODE,
|
||||
OAUTH_PROVIDERS,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
||||
@@ -124,6 +124,11 @@ from config import (
|
||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from utils.webhook import post_webhook
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
Functions.deactivate_all_functions()
|
||||
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
@@ -271,7 +276,7 @@ async def get_function_call_response(
|
||||
if tool_id in webui_app.state.TOOLS:
|
||||
toolkit_module = webui_app.state.TOOLS[tool_id]
|
||||
else:
|
||||
toolkit_module = load_toolkit_module_by_id(tool_id)
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
|
||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||
|
||||
file_handler = False
|
||||
@@ -280,6 +285,14 @@ async def get_function_call_response(
|
||||
file_handler = True
|
||||
print("file_handler: ", file_handler)
|
||||
|
||||
if hasattr(toolkit_module, "valves") and hasattr(
|
||||
toolkit_module, "Valves"
|
||||
):
|
||||
valves = Tools.get_tool_valves_by_id(tool_id)
|
||||
toolkit_module.valves = toolkit_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
function = getattr(toolkit_module, result["name"])
|
||||
function_result = None
|
||||
try:
|
||||
@@ -289,16 +302,24 @@ async def get_function_call_response(
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
# Call the function with the '__user__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
__user__["valves"] = toolkit_module.UserValves(
|
||||
**Tools.get_user_valves_by_id_and_user_id(
|
||||
tool_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
if "__messages__" in sig.parameters:
|
||||
# Call the function with the '__messages__' parameter included
|
||||
params = {
|
||||
@@ -386,54 +407,94 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
return (function.valves if function.valves else {}).get(
|
||||
"priority", 0
|
||||
)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type(
|
||||
"filter", active_only=True
|
||||
)
|
||||
]
|
||||
# Check if the model has any filters
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
filter_ids.sort(key=get_priority)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = (
|
||||
load_function_module_by_id(filter_id)
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "inlet"):
|
||||
inlet = function_module.inlet
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
data = await inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
if hasattr(function_module, "valves") and hasattr(
|
||||
function_module, "Valves"
|
||||
):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
try:
|
||||
if hasattr(function_module, "inlet"):
|
||||
inlet = function_module.inlet
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(inlet)
|
||||
params = {"body": data}
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if "__id__" in sig.parameters:
|
||||
params = {
|
||||
**params,
|
||||
"__id__": filter_id,
|
||||
}
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
data = await inlet(**params)
|
||||
else:
|
||||
data = inlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
# Set the task model
|
||||
task_model_id = data["model"]
|
||||
@@ -857,12 +918,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
|
||||
pipe = model.get("pipe")
|
||||
if pipe:
|
||||
form_data["user"] = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
async def job():
|
||||
pipe_id = form_data["model"]
|
||||
@@ -870,14 +925,62 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
|
||||
print(pipe_id)
|
||||
|
||||
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
|
||||
# Check if function is already loaded
|
||||
if pipe_id not in webui_app.state.FUNCTIONS:
|
||||
function_module, function_type, frontmatter = (
|
||||
load_function_module_by_id(pipe_id)
|
||||
)
|
||||
webui_app.state.FUNCTIONS[pipe_id] = function_module
|
||||
else:
|
||||
function_module = webui_app.state.FUNCTIONS[pipe_id]
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(
|
||||
function_module, "Valves"
|
||||
):
|
||||
|
||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
pipe = function_module.pipe
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(pipe)
|
||||
params = {"body": form_data}
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
pipe_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if form_data["stream"]:
|
||||
|
||||
async def stream_content():
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(body=form_data)
|
||||
else:
|
||||
res = pipe(body=form_data)
|
||||
try:
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(**params)
|
||||
else:
|
||||
res = pipe(**params)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
|
||||
return
|
||||
|
||||
if isinstance(res, str):
|
||||
message = stream_message_template(form_data["model"], res)
|
||||
@@ -922,10 +1025,20 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
stream_content(), media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
|
||||
try:
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(**params)
|
||||
else:
|
||||
res = pipe(**params)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return {"error": {"detail": str(e)}}
|
||||
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(body=form_data)
|
||||
res = await pipe(**params)
|
||||
else:
|
||||
res = pipe(body=form_data)
|
||||
res = pipe(**params)
|
||||
|
||||
if isinstance(res, dict):
|
||||
return res
|
||||
@@ -1008,7 +1121,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": {"id": user.id, "name": user.name, "role": user.role},
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
},
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
@@ -1033,49 +1151,88 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
else:
|
||||
pass
|
||||
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
# Check if the model has any filters
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "outlet"):
|
||||
outlet = function_module.outlet
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
# Sort filter_ids by priority, using the get_priority function
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = (
|
||||
load_function_module_by_id(filter_id)
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(
|
||||
function_module, "Valves"
|
||||
):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "outlet"):
|
||||
outlet = function_module.outlet
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(outlet)
|
||||
params = {"body": data}
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if "__id__" in sig.parameters:
|
||||
params = {
|
||||
**params,
|
||||
"__id__": filter_id,
|
||||
}
|
||||
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(**params)
|
||||
else:
|
||||
data = outlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -1989,7 +2146,6 @@ async def get_manifest_json():
|
||||
"start_url": "/",
|
||||
"display": "standalone",
|
||||
"background_color": "#343541",
|
||||
"theme_color": "#343541",
|
||||
"orientation": "portrait-primary",
|
||||
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user