enh: connection tags

This commit is contained in:
Timothy Jaeryang Baek
2025-03-11 20:37:30 +00:00
parent b427f506f6
commit c309412980
7 changed files with 88 additions and 29 deletions

View File

@@ -965,14 +965,24 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
models = await get_all_models(request, user=user)
all_models = await get_all_models(request, user=user)
# Filter out filter pipelines
models = [
model
for model in models
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
]
models = []
for model in all_models:
# Filter out filter pipelines
if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
continue
model_tags = [
tag.get("name")
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
]
tags = [tag.get("name") for tag in model.get("tags", [])]
tags = list(set(model_tags + tags))
model["tags"] = [{"name": tag} for tag in tags]
models.append(model)
model_order_list = request.app.state.config.MODEL_ORDER_LIST
if model_order_list:

View File

@@ -295,7 +295,7 @@ async def update_config(
}
@cached(ttl=3)
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
@@ -336,6 +336,7 @@ async def get_all_models(request: Request, user: UserModel = None):
)
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
model_ids = api_config.get("model_ids", [])
if len(model_ids) != 0 and "models" in response:
@@ -350,6 +351,10 @@ async def get_all_models(request: Request, user: UserModel = None):
for model in response.get("models", []):
model["model"] = f"{prefix_id}.{model['model']}"
if tags:
for model in response.get("models", []):
model["tags"] = tags
def merge_models_lists(model_lists):
merged_models = {}

View File

@@ -353,6 +353,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
)
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
if prefix_id:
for model in (
@@ -360,6 +361,12 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
):
model["id"] = f"{prefix_id}.{model['id']}"
if tags:
for model in (
response if isinstance(response, list) else response.get("data", [])
):
model["tags"] = tags
log.debug(f"get_all_models:responses() {responses}")
return responses
@@ -377,7 +384,7 @@ async def get_filtered_models(models, user):
return filtered_models
@cached(ttl=3)
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")

View File

@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
"tags": model.get("tags", []),
}
for model in ollama_models["models"]
]