mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
Merge branch 'dev' into feat/model-config
This commit is contained in:
@@ -21,6 +21,7 @@ from utils.utils import (
|
||||
)
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
ENABLE_OPENAI_API,
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
CACHE_DIR,
|
||||
@@ -47,12 +48,15 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.MODEL_CONFIG = MODEL_CONFIG.value.get("openai", [])
|
||||
|
||||
|
||||
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
||||
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||
|
||||
@@ -70,6 +74,21 @@ async def check_url(request: Request, call_next):
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_config(user=Depends(get_admin_user)):
|
||||
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
|
||||
|
||||
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
enable_openai_api: Optional[bool] = None
|
||||
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
|
||||
app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
|
||||
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
|
||||
|
||||
|
||||
class UrlsUpdateForm(BaseModel):
|
||||
urls: List[str]
|
||||
|
||||
@@ -166,11 +185,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
async def fetch_url(url, key):
|
||||
timeout = aiohttp.ClientTimeout(total=5)
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
return await response.json()
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
return await response.json()
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
log.error(f"Connection error: {e}")
|
||||
@@ -202,7 +225,7 @@ async def get_all_models():
|
||||
if (
|
||||
len(app.state.config.OPENAI_API_KEYS) == 1
|
||||
and app.state.config.OPENAI_API_KEYS[0] == ""
|
||||
):
|
||||
) or not app.state.config.ENABLE_OPENAI_API:
|
||||
models = {"data": []}
|
||||
else:
|
||||
tasks = [
|
||||
@@ -248,11 +271,11 @@ def add_custom_info_to_model(model: dict):
|
||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["data"] = list(
|
||||
filter(
|
||||
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
models["data"],
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user