This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 20:15:23 -08:00
parent fe5519e0a2
commit a07ff56c50
2 changed files with 55 additions and 50 deletions

View File

@@ -245,41 +245,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
return merged_list
async def get_all_models_responses(request: Request) -> list:
if not request.app.state.config.ENABLE_OPENAI_API:
return []
@@ -379,7 +344,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses()
responses = await get_all_models_responses(request)
def extract_data(response):
if response and "data" in response:
@@ -388,6 +353,40 @@ async def get_all_models(request: Request) -> dict[str, list]:
return response
return None
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
return merged_list
models = {"data": merge_models_lists(map(extract_data, responses))}
log.debug(f"models: {models}")