diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 5e0a74d612..aae86e49ed 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -90,8 +90,6 @@ async def send_get_request(url, key=None, user: UserModel = None): return None - - def openai_reasoning_model_handler(payload): """ Handle reasoning model specific parameters @@ -365,9 +363,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: request_tasks = [] for idx, url in enumerate(api_base_urls): - if (str(idx) not in api_configs) and ( - url not in api_configs # Legacy support - ): + if (str(idx) not in api_configs) and (url not in api_configs): # Legacy support request_tasks.append( send_get_request( f"{url}/models", @@ -378,9 +374,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: else: api_config = api_configs.get( str(idx), - api_configs.get( - url, {} - ), # Legacy support + api_configs.get(url, {}), # Legacy support ) enable = api_config.get("enable", True) @@ -423,9 +417,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: url = api_base_urls[idx] api_config = api_configs.get( str(idx), - api_configs.get( - url, {} - ), # Legacy support + api_configs.get(url, {}), # Legacy support ) connection_type = api_config.get("connection_type", "external") @@ -462,12 +454,8 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: async def get_filtered_models(models, user, db=None): # Filter models based on user access control model_ids = [model["id"] for model in models.get("data", [])] - model_infos = { - m.id: m for m in Models.get_models_by_ids(model_ids, db=db) - } - user_group_ids = { - g.id for g in Groups.get_groups_by_member_id(user.id, db=db) - } + model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)} filtered_models = [] for model in models.get("data", []): @@ -532,11 +520,9 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: for model in model_list: model_id = model.get("id") or model.get("name") - if ( - "api.openai.com" - in api_base_urls[idx] - and not is_supported_openai_models(model_id) - ): + if "api.openai.com" in api_base_urls[ + idx + ] and not is_supported_openai_models(model_id): # Skip unwanted OpenAI models continue @@ -992,11 +978,11 @@ async def generate_chat_completion( ) # Check if model is already in app state cache to avoid expensive get_all_models() call - # This significantly reduces TTFT when models are already cached - model = request.app.state.OPENAI_MODELS.get(model_id) if request.app.state.OPENAI_MODELS else None - if not model: + models = request.app.state.OPENAI_MODELS + if not models or model not in models: await get_all_models(request, user=user) - model = request.app.state.OPENAI_MODELS.get(model_id) + models = request.app.state.OPENAI_MODELS + model = models.get(model_id) if model: idx = model["urlIdx"]