mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
refac: tool name collision handling
This commit is contained in:
@@ -5,6 +5,7 @@ import inspect
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
@@ -85,7 +86,9 @@ async def get_tools(
|
||||
tool_server_data = server
|
||||
break
|
||||
|
||||
assert tool_server_data is not None
|
||||
if tool_server_data is None:
|
||||
log.warning(f"Tool server data not found for {server_id}")
|
||||
continue
|
||||
|
||||
tool_server_idx = tool_server_data.get("idx", 0)
|
||||
tool_server_connection = (
|
||||
@@ -131,14 +134,15 @@ async def get_tools(
|
||||
"spec": spec,
|
||||
}
|
||||
|
||||
# TODO: if collision, prepend toolkit name
|
||||
if function_name in tools_dict:
|
||||
# Handle function name collisions
|
||||
while function_name in tools_dict:
|
||||
log.warning(
|
||||
f"Tool {function_name} already exists in another tools!"
|
||||
)
|
||||
log.warning(f"Discarding {tool_id}.{function_name}")
|
||||
else:
|
||||
tools_dict[function_name] = tool_dict
|
||||
# Prepend server ID to function name
|
||||
function_name = f"{server_id}_{function_name}"
|
||||
|
||||
tools_dict[function_name] = tool_dict
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
@@ -198,14 +202,15 @@ async def get_tools(
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: if collision, prepend toolkit name
|
||||
if function_name in tools_dict:
|
||||
# Handle function name collisions
|
||||
while function_name in tools_dict:
|
||||
log.warning(
|
||||
f"Tool {function_name} already exists in another tools!"
|
||||
)
|
||||
log.warning(f"Discarding {tool_id}.{function_name}")
|
||||
else:
|
||||
tools_dict[function_name] = tool_dict
|
||||
# Prepend tool ID to function name
|
||||
function_name = f"{tool_id}_{function_name}"
|
||||
|
||||
tools_dict[function_name] = tool_dict
|
||||
|
||||
return tools_dict
|
||||
|
||||
@@ -453,8 +458,8 @@ async def set_tool_servers(request: Request):
|
||||
)
|
||||
|
||||
if request.app.state.redis is not None:
|
||||
await request.app.state.redis.hmset(
|
||||
"tool_servers", request.app.state.TOOL_SERVERS
|
||||
await request.app.state.redis.set(
|
||||
"tool_servers", json.dumps(request.app.state.TOOL_SERVERS)
|
||||
)
|
||||
|
||||
return request.app.state.TOOL_SERVERS
|
||||
@@ -463,7 +468,10 @@ async def set_tool_servers(request: Request):
|
||||
async def get_tool_servers(request: Request):
|
||||
tool_servers = []
|
||||
if request.app.state.redis is not None:
|
||||
tool_servers = await request.app.state.redis.hgetall("tool_servers")
|
||||
try:
|
||||
tool_servers = json.loads(await request.app.state.redis.get("tool_servers"))
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching tool_servers from Redis: {e}")
|
||||
|
||||
if not tool_servers:
|
||||
await set_tool_servers(request)
|
||||
@@ -536,7 +544,10 @@ async def get_tool_servers_data(
|
||||
elif auth_type == "session":
|
||||
token = session_token
|
||||
|
||||
id = info.get("id", idx)
|
||||
id = info.get("id")
|
||||
if not id:
|
||||
id = str(idx)
|
||||
|
||||
server_entries.append((id, idx, server, full_url, info, token))
|
||||
|
||||
# Create async tasks to fetch data
|
||||
|
||||
Reference in New Issue
Block a user