enh: load tool by url

This commit is contained in:
Timothy Jaeryang Baek
2025-05-29 02:08:54 +04:00
parent 4461122a0e
commit 85a384fab5
6 changed files with 244 additions and 12 deletions

View File

@@ -2,6 +2,9 @@ import logging
from pathlib import Path
from typing import Optional
import time
import re
import aiohttp
from pydantic import BaseModel, HttpUrl
from open_webui.models.tools import (
ToolForm,
@@ -21,6 +24,7 @@ from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.tools import get_tool_servers_data
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -95,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
return tools
############################
# LoadFunctionFromLink
############################
class LoadUrlForm(BaseModel):
url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
# Handle 'tree' (folder) URLs (add main.py at the end)
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
if m1:
org, repo, branch, path = m1.groups()
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
# Handle 'blob' (file) URLs
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
if m2:
org, repo, branch, path = m2.groups()
return (
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
)
# No match; return as-is
return url
@router.post("/load/url", response_model=Optional[dict])
async def load_tool_from_url(
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
# NOTE: This is NOT a SSRF vulnerability:
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
# and does NOT accept untrusted user input. Access is enforced by authentication.
url = str(form_data.url)
if not url:
raise HTTPException(status_code=400, detail="Please enter a valid URL")
url = github_url_to_raw_url(url)
url_parts = url.rstrip("/").split("/")
file_name = url_parts[-1]
tool_name = (
file_name[:-3]
if (
file_name.endswith(".py")
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
)
else url_parts[-2] if len(url_parts) > 1 else "function"
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status, detail="Failed to fetch the tool"
)
data = await resp.text()
if not data:
raise HTTPException(
status_code=400, detail="No data received from the URL"
)
return {
"name": tool_name,
"content": data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
############################
# ExportTools
############################