This commit is contained in:
Timothy Jaeryang Baek
2025-04-05 04:05:52 -06:00
parent 0f310b3509
commit 0c0505e1cd
12 changed files with 613 additions and 368 deletions

View File

@@ -2,9 +2,10 @@ import inspect
import logging
import re
import inspect
import uuid
import aiohttp
import asyncio
from typing import Any, Awaitable, Callable, get_type_hints
from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union
from functools import update_wrapper, partial
@@ -217,3 +218,260 @@ def get_tools_specs(tool_class: object) -> list[dict]:
function_list = get_callable_attributes(tool_class)
models = map(function_to_pydantic_model, function_list)
return [convert_to_openai_function(tool) for tool in models]
import copy
def resolve_schema(schema, components):
"""
Recursively resolves a JSON schema using OpenAPI components.
"""
if not schema:
return {}
if "$ref" in schema:
ref_path = schema["$ref"]
ref_parts = ref_path.strip("#/").split("/")
resolved = components
for part in ref_parts[1:]: # Skip the initial 'components'
resolved = resolved.get(part, {})
return resolve_schema(resolved, components)
resolved_schema = copy.deepcopy(schema)
# Recursively resolve inner schemas
if "properties" in resolved_schema:
for prop, prop_schema in resolved_schema["properties"].items():
resolved_schema["properties"][prop] = resolve_schema(
prop_schema, components
)
if "items" in resolved_schema:
resolved_schema["items"] = resolve_schema(resolved_schema["items"], components)
return resolved_schema
def convert_openapi_to_tool_payload(openapi_spec):
"""
Converts an OpenAPI specification into a custom tool payload structure.
Args:
openapi_spec (dict): The OpenAPI specification as a Python dict.
Returns:
list: A list of tool payloads.
"""
tool_payload = []
for path, methods in openapi_spec.get("paths", {}).items():
for method, operation in methods.items():
tool = {
"type": "function",
"name": operation.get("operationId"),
"description": operation.get("summary", "No description available."),
"parameters": {"type": "object", "properties": {}, "required": []},
}
# Extract path and query parameters
for param in operation.get("parameters", []):
param_name = param["name"]
param_schema = param.get("schema", {})
tool["parameters"]["properties"][param_name] = {
"type": param_schema.get("type"),
"description": param_schema.get("description", ""),
}
if param.get("required"):
tool["parameters"]["required"].append(param_name)
# Extract and resolve requestBody if available
request_body = operation.get("requestBody")
if request_body:
content = request_body.get("content", {})
json_schema = content.get("application/json", {}).get("schema")
if json_schema:
resolved_schema = resolve_schema(
json_schema, openapi_spec.get("components", {})
)
if resolved_schema.get("properties"):
tool["parameters"]["properties"].update(
resolved_schema["properties"]
)
if "required" in resolved_schema:
tool["parameters"]["required"] = list(
set(
tool["parameters"]["required"]
+ resolved_schema["required"]
)
)
elif resolved_schema.get("type") == "array":
tool["parameters"] = resolved_schema # special case for array
tool_payload.append(tool)
return tool_payload
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
if token:
headers["Authorization"] = f"Bearer {token}"
error = None
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status != 200:
error_body = await response.json()
raise Exception(error_body)
res = await response.json()
except Exception as err:
print("Error:", err)
if isinstance(err, dict) and "detail" in err:
error = err["detail"]
else:
error = str(err)
raise Exception(error)
data = {
"openapi": res,
"info": res.get("info", {}),
"specs": convert_openapi_to_tool_payload(res),
}
print("Fetched data:", data)
return data
async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
enabled_servers = [
server for server in servers if server.get("config", {}).get("enable")
]
urls = [
(
server,
f"{server.get('url')}/{server.get('path', 'openapi.json')}",
server.get("key", ""),
)
for server in enabled_servers
]
tasks = [get_tool_server_data(token, url) for _, url, token in urls]
results: List[Dict[str, Any]] = []
responses = await asyncio.gather(*tasks, return_exceptions=True)
for (server, _, _), response in zip(urls, responses):
if isinstance(response, Exception):
url_path = server.get("path", "openapi.json")
full_url = f"{server.get('url')}/{url_path}"
print(f"Failed to connect to {full_url} OpenAPI tool server")
else:
results.append(
{
"url": server.get("url"),
"openapi": response["openapi"],
"info": response["info"],
"specs": response["specs"],
}
)
return results
async def execute_tool_server(
token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any]
) -> Any:
error = None
try:
openapi = server_data.get("openapi", {})
paths = openapi.get("paths", {})
matching_route = None
for route_path, methods in paths.items():
for http_method, operation in methods.items():
if isinstance(operation, dict) and operation.get("operationId") == name:
matching_route = (route_path, methods)
break
if matching_route:
break
if not matching_route:
raise Exception(f"No matching route found for operationId: {name}")
route_path, methods = matching_route
method_entry = None
for http_method, operation in methods.items():
if operation.get("operationId") == name:
method_entry = (http_method.lower(), operation)
break
if not method_entry:
raise Exception(f"No matching method found for operationId: {name}")
http_method, operation = method_entry
path_params = {}
query_params = {}
body_params = {}
for param in operation.get("parameters", []):
param_name = param["name"]
param_in = param["in"]
if param_name in params:
if param_in == "path":
path_params[param_name] = params[param_name]
elif param_in == "query":
query_params[param_name] = params[param_name]
final_url = f"{url}{route_path}"
for key, value in path_params.items():
final_url = final_url.replace(f"{{{key}}}", str(value))
if query_params:
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
final_url = f"{final_url}?{query_string}"
if operation.get("requestBody", {}).get("content"):
if params:
body_params = params
else:
raise Exception(
f"Request body expected for operation '{name}' but none found."
)
headers = {"Content-Type": "application/json"}
if token:
headers["Authorization"] = f"Bearer {token}"
async with aiohttp.ClientSession() as session:
request_method = getattr(session, http_method.lower())
if http_method in ["post", "put", "patch"]:
async with request_method(
final_url, json=body_params, headers=headers
) as response:
if response.status >= 400:
text = await response.text()
raise Exception(f"HTTP error {response.status}: {text}")
return await response.json()
else:
async with request_method(final_url, headers=headers) as response:
if response.status >= 400:
text = await response.text()
raise Exception(f"HTTP error {response.status}: {text}")
return await response.json()
except Exception as err:
error = str(err)
print("API Request Error:", error)
return {"error": error}