mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
ensure from config (#19902)
This commit is contained in:
@@ -451,6 +451,50 @@ class OAuthClientManager:
|
|||||||
}
|
}
|
||||||
return self.clients[client_id]
|
return self.clients[client_id]
|
||||||
|
|
||||||
|
def ensure_client_from_config(self, client_id):
|
||||||
|
"""
|
||||||
|
Lazy-load an OAuth client from the current TOOL_SERVER_CONNECTIONS
|
||||||
|
config if it hasn't been registered on this node yet.
|
||||||
|
"""
|
||||||
|
if client_id in self.clients:
|
||||||
|
return self.clients[client_id]["client"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
connections = getattr(self.app.state.config, "TOOL_SERVER_CONNECTIONS", [])
|
||||||
|
except Exception:
|
||||||
|
connections = []
|
||||||
|
|
||||||
|
for connection in connections or []:
|
||||||
|
if connection.get("type", "openapi") != "mcp":
|
||||||
|
continue
|
||||||
|
if connection.get("auth_type", "none") != "oauth_2.1":
|
||||||
|
continue
|
||||||
|
|
||||||
|
server_id = connection.get("info", {}).get("id")
|
||||||
|
if not server_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
expected_client_id = f"mcp:{server_id}"
|
||||||
|
if client_id != expected_client_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
oauth_client_info = connection.get("info", {}).get("oauth_client_info", "")
|
||||||
|
if not oauth_client_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
oauth_client_info = decrypt_data(oauth_client_info)
|
||||||
|
return self.add_client(
|
||||||
|
expected_client_id, OAuthClientInformationFull(**oauth_client_info)
|
||||||
|
)["client"]
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
f"Failed to lazily add OAuth client {expected_client_id} from config: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def remove_client(self, client_id):
|
def remove_client(self, client_id):
|
||||||
if client_id in self.clients:
|
if client_id in self.clients:
|
||||||
del self.clients[client_id]
|
del self.clients[client_id]
|
||||||
@@ -718,10 +762,13 @@ class OAuthClientManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
||||||
client = self.get_client(client_id)
|
client = self.get_client(client_id) or self.ensure_client_from_config(client_id)
|
||||||
if client is None:
|
if client is None:
|
||||||
raise HTTPException(404)
|
raise HTTPException(404)
|
||||||
client_info = self.get_client_info(client_id)
|
client_info = self.get_client_info(client_id)
|
||||||
|
if client_info is None:
|
||||||
|
# ensure_client_from_config registers client_info too
|
||||||
|
client_info = self.get_client_info(client_id)
|
||||||
if client_info is None:
|
if client_info is None:
|
||||||
raise HTTPException(404)
|
raise HTTPException(404)
|
||||||
|
|
||||||
@@ -732,7 +779,7 @@ class OAuthClientManager:
|
|||||||
return await client.authorize_redirect(request, redirect_uri_str)
|
return await client.authorize_redirect(request, redirect_uri_str)
|
||||||
|
|
||||||
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
||||||
client = self.get_client(client_id)
|
client = self.get_client(client_id) or self.ensure_client_from_config(client_id)
|
||||||
if client is None:
|
if client is None:
|
||||||
raise HTTPException(404)
|
raise HTTPException(404)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user