This commit is contained in:
Timothy Jaeryang Baek
2026-05-09 06:01:02 +09:00
parent d78c247036
commit 9907c0a25a
6 changed files with 58 additions and 13 deletions

View File

@@ -8,6 +8,7 @@ from typing import Optional
from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import get_custom_headers
from open_webui.config import get_config, save_config, async_save_config
from open_webui.config import BannerModel
@@ -103,6 +104,7 @@ class OAuthClientRegistrationForm(BaseModel):
client_id: str
client_name: Optional[str] = None
client_secret: Optional[str] = None
oauth_server_url: Optional[str] = None
@router.post('/oauth/clients/register')
@@ -117,18 +119,20 @@ async def register_oauth_client(
if type:
oauth_client_id = f'{type}:{form_data.client_id}'
oauth_server_url = form_data.oauth_server_url if form_data.oauth_server_url else form_data.url
if form_data.client_secret:
# Static credentials: skip dynamic registration, build from provided credentials
oauth_client_info = await get_oauth_client_info_with_static_credentials(
request,
oauth_client_id,
form_data.url,
oauth_server_url,
oauth_client_id=form_data.client_id,
oauth_client_secret=form_data.client_secret,
)
else:
oauth_client_info = await get_oauth_client_info_with_dynamic_client_registration(
request, oauth_client_id, form_data.url
request, oauth_client_id, oauth_server_url
)
return {
'status': True,
@@ -155,6 +159,7 @@ class ToolServerConnection(BaseModel):
headers: Optional[dict | str] = None
key: Optional[str]
config: Optional[dict]
info: Optional[dict] = None
model_config = ConfigDict(extra='allow')
@@ -369,7 +374,8 @@ async def verify_tool_servers_config(request: Request, form_data: ToolServerConn
try:
if form_data.type == 'mcp':
if form_data.auth_type in ('oauth_2.1', 'oauth_2.1_static'):
discovery_urls = await get_discovery_urls(form_data.url)
oauth_server_url = form_data.info.get('oauth_server_url') if form_data.info and form_data.info.get('oauth_server_url') else form_data.url
discovery_urls = await get_discovery_urls(oauth_server_url)
for discovery_url in discovery_urls:
log.debug(f'Trying to fetch OAuth 2.1 discovery document from {discovery_url}')
async with aiohttp.ClientSession(
@@ -428,7 +434,8 @@ async def verify_tool_servers_config(request: Request, form_data: ToolServerConn
if form_data.headers and isinstance(form_data.headers, dict):
if headers is None:
headers = {}
headers.update(form_data.headers)
custom_headers = get_custom_headers(form_data.headers, user)
headers.update(custom_headers)
await client.connect(form_data.url, headers=headers)
specs = await client.list_tool_specs()
@@ -472,7 +479,8 @@ async def verify_tool_servers_config(request: Request, form_data: ToolServerConn
if form_data.headers and isinstance(form_data.headers, dict):
if headers is None:
headers = {}
headers.update(form_data.headers)
custom_headers = get_custom_headers(form_data.headers, user)
headers.update(custom_headers)
url = get_tool_server_url(form_data.url, form_data.path)
return await get_tool_server_data(url, headers=headers)

View File

@@ -62,7 +62,7 @@ from open_webui.utils.session_pool import (
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.headers import include_user_info_headers, get_custom_headers
from open_webui.utils.anthropic import is_anthropic_url, get_anthropic_models
log = logging.getLogger(__name__)
@@ -215,7 +215,8 @@ async def get_headers_and_cookies(
headers['Authorization'] = f'Bearer {token}'
if config.get('headers') and isinstance(config.get('headers'), dict):
headers = {**headers, **config.get('headers')}
custom_headers = get_custom_headers(config.get('headers'), user, metadata)
headers.update(custom_headers)
return headers, cookies

View File

@@ -16,3 +16,25 @@ def include_user_info_headers(headers, user):
FORWARD_USER_INFO_HEADER_USER_EMAIL: user.email,
FORWARD_USER_INFO_HEADER_USER_ROLE: user.role,
}
def get_custom_headers(custom_headers: dict, user=None, metadata: dict = None) -> dict:
if not custom_headers or not isinstance(custom_headers, dict):
return {}
metadata = metadata or {}
template_vars = {
'{{CHAT_ID}}': metadata.get('chat_id', '') or '',
'{{MESSAGE_ID}}': metadata.get('message_id', '') or '',
'{{USER_ID}}': (user.id if user else '') or '',
'{{USER_NAME}}': (user.name if user else '') or '',
}
parsed_headers = {}
for key, value in custom_headers.items():
if not isinstance(value, str):
value = str(value)
for token, val in template_vars.items():
value = value.replace(token, val)
parsed_headers[key] = value
return parsed_headers

View File

@@ -54,7 +54,7 @@ from open_webui.env import (
FORWARD_SESSION_INFO_HEADER_MESSAGE_ID,
REDIS_KEY_PREFIX,
)
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.headers import include_user_info_headers, get_custom_headers
from open_webui.tools.builtin import (
search_web,
fetch_url,
@@ -337,8 +337,9 @@ async def get_tools(request: Request, tool_ids: list[str], user: UserModel, extr
connection_headers = tool_server_connection.get('headers', None)
if connection_headers and isinstance(connection_headers, dict):
for key, value in connection_headers.items():
headers[key] = value
metadata = extra_params.get('__metadata__', {})
custom_headers = get_custom_headers(connection_headers, user, metadata)
headers.update(custom_headers)
# Add user info headers if enabled
if ENABLE_FORWARD_USER_INFO_HEADERS and user: