mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-18 05:05:09 +02:00
refac
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Optional
|
||||
|
||||
from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.headers import get_custom_headers
|
||||
from open_webui.config import get_config, save_config, async_save_config
|
||||
from open_webui.config import BannerModel
|
||||
|
||||
@@ -103,6 +104,7 @@ class OAuthClientRegistrationForm(BaseModel):
|
||||
client_id: str
|
||||
client_name: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
oauth_server_url: Optional[str] = None
|
||||
|
||||
|
||||
@router.post('/oauth/clients/register')
|
||||
@@ -117,18 +119,20 @@ async def register_oauth_client(
|
||||
if type:
|
||||
oauth_client_id = f'{type}:{form_data.client_id}'
|
||||
|
||||
oauth_server_url = form_data.oauth_server_url if form_data.oauth_server_url else form_data.url
|
||||
|
||||
if form_data.client_secret:
|
||||
# Static credentials: skip dynamic registration, build from provided credentials
|
||||
oauth_client_info = await get_oauth_client_info_with_static_credentials(
|
||||
request,
|
||||
oauth_client_id,
|
||||
form_data.url,
|
||||
oauth_server_url,
|
||||
oauth_client_id=form_data.client_id,
|
||||
oauth_client_secret=form_data.client_secret,
|
||||
)
|
||||
else:
|
||||
oauth_client_info = await get_oauth_client_info_with_dynamic_client_registration(
|
||||
request, oauth_client_id, form_data.url
|
||||
request, oauth_client_id, oauth_server_url
|
||||
)
|
||||
return {
|
||||
'status': True,
|
||||
@@ -155,6 +159,7 @@ class ToolServerConnection(BaseModel):
|
||||
headers: Optional[dict | str] = None
|
||||
key: Optional[str]
|
||||
config: Optional[dict]
|
||||
info: Optional[dict] = None
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
@@ -369,7 +374,8 @@ async def verify_tool_servers_config(request: Request, form_data: ToolServerConn
|
||||
try:
|
||||
if form_data.type == 'mcp':
|
||||
if form_data.auth_type in ('oauth_2.1', 'oauth_2.1_static'):
|
||||
discovery_urls = await get_discovery_urls(form_data.url)
|
||||
oauth_server_url = form_data.info.get('oauth_server_url') if form_data.info and form_data.info.get('oauth_server_url') else form_data.url
|
||||
discovery_urls = await get_discovery_urls(oauth_server_url)
|
||||
for discovery_url in discovery_urls:
|
||||
log.debug(f'Trying to fetch OAuth 2.1 discovery document from {discovery_url}')
|
||||
async with aiohttp.ClientSession(
|
||||
@@ -428,7 +434,8 @@ async def verify_tool_servers_config(request: Request, form_data: ToolServerConn
|
||||
if form_data.headers and isinstance(form_data.headers, dict):
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update(form_data.headers)
|
||||
custom_headers = get_custom_headers(form_data.headers, user)
|
||||
headers.update(custom_headers)
|
||||
|
||||
await client.connect(form_data.url, headers=headers)
|
||||
specs = await client.list_tool_specs()
|
||||
@@ -472,7 +479,8 @@ async def verify_tool_servers_config(request: Request, form_data: ToolServerConn
|
||||
if form_data.headers and isinstance(form_data.headers, dict):
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update(form_data.headers)
|
||||
custom_headers = get_custom_headers(form_data.headers, user)
|
||||
headers.update(custom_headers)
|
||||
|
||||
url = get_tool_server_url(form_data.url, form_data.path)
|
||||
return await get_tool_server_data(url, headers=headers)
|
||||
|
||||
@@ -62,7 +62,7 @@ from open_webui.utils.session_pool import (
|
||||
)
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.headers import include_user_info_headers
|
||||
from open_webui.utils.headers import include_user_info_headers, get_custom_headers
|
||||
from open_webui.utils.anthropic import is_anthropic_url, get_anthropic_models
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -215,7 +215,8 @@ async def get_headers_and_cookies(
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
|
||||
if config.get('headers') and isinstance(config.get('headers'), dict):
|
||||
headers = {**headers, **config.get('headers')}
|
||||
custom_headers = get_custom_headers(config.get('headers'), user, metadata)
|
||||
headers.update(custom_headers)
|
||||
|
||||
return headers, cookies
|
||||
|
||||
|
||||
@@ -16,3 +16,25 @@ def include_user_info_headers(headers, user):
|
||||
FORWARD_USER_INFO_HEADER_USER_EMAIL: user.email,
|
||||
FORWARD_USER_INFO_HEADER_USER_ROLE: user.role,
|
||||
}
|
||||
|
||||
def get_custom_headers(custom_headers: dict, user=None, metadata: dict = None) -> dict:
|
||||
if not custom_headers or not isinstance(custom_headers, dict):
|
||||
return {}
|
||||
|
||||
metadata = metadata or {}
|
||||
template_vars = {
|
||||
'{{CHAT_ID}}': metadata.get('chat_id', '') or '',
|
||||
'{{MESSAGE_ID}}': metadata.get('message_id', '') or '',
|
||||
'{{USER_ID}}': (user.id if user else '') or '',
|
||||
'{{USER_NAME}}': (user.name if user else '') or '',
|
||||
}
|
||||
|
||||
parsed_headers = {}
|
||||
for key, value in custom_headers.items():
|
||||
if not isinstance(value, str):
|
||||
value = str(value)
|
||||
for token, val in template_vars.items():
|
||||
value = value.replace(token, val)
|
||||
parsed_headers[key] = value
|
||||
|
||||
return parsed_headers
|
||||
|
||||
@@ -54,7 +54,7 @@ from open_webui.env import (
|
||||
FORWARD_SESSION_INFO_HEADER_MESSAGE_ID,
|
||||
REDIS_KEY_PREFIX,
|
||||
)
|
||||
from open_webui.utils.headers import include_user_info_headers
|
||||
from open_webui.utils.headers import include_user_info_headers, get_custom_headers
|
||||
from open_webui.tools.builtin import (
|
||||
search_web,
|
||||
fetch_url,
|
||||
@@ -337,8 +337,9 @@ async def get_tools(request: Request, tool_ids: list[str], user: UserModel, extr
|
||||
|
||||
connection_headers = tool_server_connection.get('headers', None)
|
||||
if connection_headers and isinstance(connection_headers, dict):
|
||||
for key, value in connection_headers.items():
|
||||
headers[key] = value
|
||||
metadata = extra_params.get('__metadata__', {})
|
||||
custom_headers = get_custom_headers(connection_headers, user, metadata)
|
||||
headers.update(custom_headers)
|
||||
|
||||
# Add user info headers if enabled
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
|
||||
Reference in New Issue
Block a user