This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 18:53:38 -08:00
parent ccdf51588e
commit 772f5ccd60
3 changed files with 180 additions and 246 deletions

View File

@@ -697,11 +697,11 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
if filter_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
@@ -828,7 +828,7 @@ async def chat_completion_tools_handler(
models,
)
tools = get_tools(
webui_app,
app,
tool_ids,
user,
{
@@ -1406,7 +1406,7 @@ async def commit_session_after_request(request: Request, call_next):
@app.middleware("http")
async def check_url(request: Request, call_next):
start_time = int(time.time())
request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY
request.state.enable_api_key = app.state.config.ENABLE_API_KEY
response = await call_next(request)
process_time = int(time.time()) - start_time
response.headers["X-Process-Time"] = str(process_time)
@@ -1913,11 +1913,11 @@ async def get_all_models():
]
def get_function_module_by_id(function_id):
if function_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[function_id]
if function_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[function_id]
else:
function_module, _, _ = load_function_module_by_id(function_id)
webui_app.state.FUNCTIONS[function_id] = function_module
app.state.FUNCTIONS[function_id] = function_module
for model in models:
action_ids = [
@@ -1953,7 +1953,7 @@ async def get_models(user=Depends(get_verified_user)):
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
]
model_order_list = webui_app.state.config.MODEL_ORDER_LIST
model_order_list = app.state.config.MODEL_ORDER_LIST
if model_order_list:
model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)}
# Sort models by order list priority, with fallback for those not in the list
@@ -2229,11 +2229,11 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
if filter_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
@@ -2340,11 +2340,11 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
}
)
if action_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id]
if action_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
webui_app.state.FUNCTIONS[action_id] = function_module
app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
@@ -2448,17 +2448,17 @@ async def get_app_config(request: Request):
},
"features": {
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_ldap": webui_app.state.config.ENABLE_LDAP,
"enable_api_key": webui_app.state.config.ENABLE_API_KEY,
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM,
"auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_ldap": app.state.config.ENABLE_LDAP,
"enable_api_key": app.state.config.ENABLE_API_KEY,
"enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
**(
{
"enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
}
@@ -2468,8 +2468,8 @@ async def get_app_config(request: Request):
},
**(
{
"default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"audio": {
"tts": {
"engine": audio_app.state.config.TTS_ENGINE,
@@ -2484,7 +2484,7 @@ async def get_app_config(request: Request):
"max_size": retrieval_app.state.config.FILE_MAX_SIZE,
"max_count": retrieval_app.state.config.FILE_MAX_COUNT,
},
"permissions": {**webui_app.state.config.USER_PERMISSIONS},
"permissions": {**app.state.config.USER_PERMISSIONS},
}
if user is not None
else {}
@@ -2506,7 +2506,7 @@ async def get_webhook_url(user=Depends(get_admin_user)):
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {"url": app.state.config.WEBHOOK_URL}