refac: tool servers

This commit is contained in:
Timothy Jaeryang Baek
2025-04-05 04:40:01 -06:00
parent 61778c8f71
commit 9747a0e1f1
4 changed files with 95 additions and 40 deletions

View File

@@ -5,7 +5,7 @@ import inspect
import aiohttp
import asyncio
from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union
from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional
from functools import update_wrapper, partial
@@ -348,40 +348,47 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
return data
async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
enabled_servers = [
server for server in servers if server.get("config", {}).get("enable")
]
urls = [
(
server,
f"{server.get('url')}/{server.get('path', 'openapi.json')}",
server.get("key", ""),
)
for server in enabled_servers
]
tasks = [get_tool_server_data(token, url) for _, url, token in urls]
results: List[Dict[str, Any]] = []
responses = await asyncio.gather(*tasks, return_exceptions=True)
for (server, _, _), response in zip(urls, responses):
if isinstance(response, Exception):
async def get_tool_servers_data(
servers: List[Dict[str, Any]], session_token: Optional[str] = None
) -> List[Dict[str, Any]]:
# Prepare list of enabled servers along with their original index
server_entries = []
for idx, server in enumerate(servers):
if server.get("config", {}).get("enable"):
url_path = server.get("path", "openapi.json")
full_url = f"{server.get('url')}/{url_path}"
print(f"Failed to connect to {full_url} OpenAPI tool server")
else:
results.append(
{
"url": server.get("url"),
"openapi": response["openapi"],
"info": response["info"],
"specs": response["specs"],
}
)
auth_type = server.get("auth_type", "bearer")
token = None
if auth_type == "bearer":
token = server.get("key", "")
elif auth_type == "session":
token = session_token
server_entries.append((idx, server, full_url, token))
# Create async tasks to fetch data
tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
# Execute tasks concurrently
responses = await asyncio.gather(*tasks, return_exceptions=True)
# Build final results with index and server metadata
results = []
for (idx, server, url, _), response in zip(server_entries, responses):
if isinstance(response, Exception):
print(f"Failed to connect to {url} OpenAPI tool server")
continue
results.append(
{
"idx": idx,
"url": server.get("url"),
"openapi": response.get("openapi"),
"info": response.get("info"),
"specs": response.get("specs"),
}
)
return results