feat: switch to config proxy, remove config_get/set

This commit is contained in:
Jun Siang Cheah
2024-05-10 15:03:24 +08:00
parent f712c90019
commit 298e6848b3
11 changed files with 340 additions and 379 deletions

View File

@@ -26,8 +26,7 @@ from config import (
CACHE_DIR,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
config_set,
config_get,
AppConfig,
)
from typing import List, Optional
@@ -47,11 +46,13 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
@@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel):
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models()
config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls)
return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
app.state.config.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
config_set(app.state.OPENAI_API_KEYS, form_data.keys)
return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
app.state.config.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
idx = None
try:
idx = config_get(app.state.OPENAI_API_BASE_URLS).index(
"https://api.openai.com/v1"
)
idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
body = await request.body()
name = hashlib.sha256(body).hexdigest()
@@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
headers = {}
headers["Authorization"] = (
f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}"
)
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.post(
url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech",
url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body,
headers=headers,
stream=True,
@@ -187,7 +184,7 @@ def merge_models_lists(model_lists):
{**model, "urlIdx": idx}
for model in models
if "api.openai.com"
not in config_get(app.state.OPENAI_API_BASE_URLS)[idx]
not in app.state.config.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
@@ -199,14 +196,14 @@ async def get_all_models():
log.info("get_all_models()")
if (
len(config_get(app.state.OPENAI_API_KEYS)) == 1
and config_get(app.state.OPENAI_API_KEYS)[0] == ""
len(app.state.config.OPENAI_API_KEYS) == 1
and app.state.config.OPENAI_API_KEYS[0] == ""
):
models = {"data": []}
else:
tasks = [
fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx])
for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS))
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
@@ -238,19 +235,18 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None:
models = await get_all_models()
if config_get(app.state.ENABLE_MODEL_FILTER):
if app.state.ENABLE_MODEL_FILTER:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"]
in config_get(app.state.MODEL_FILTER_LIST),
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
models["data"],
)
)
return models
return models
else:
url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx]
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
r = None
@@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e)
url = config_get(app.state.OPENAI_API_BASE_URLS)[idx]
key = config_get(app.state.OPENAI_API_KEYS)[idx]
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"