Merge remote-tracking branch 'upstream/dev' into playwright

This commit is contained in:
Rory
2025-02-05 17:47:58 -06:00
95 changed files with 2173 additions and 800 deletions

View File

@@ -927,6 +927,12 @@ USER_PERMISSIONS_FEATURES_IMAGE_GENERATION = (
== "true"
)
USER_PERMISSIONS_FEATURES_CODE_INTERPRETER = (
os.environ.get("USER_PERMISSIONS_FEATURES_CODE_INTERPRETER", "True").lower()
== "true"
)
DEFAULT_USER_PERMISSIONS = {
"workspace": {
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
@@ -944,6 +950,7 @@ DEFAULT_USER_PERMISSIONS = {
"features": {
"web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
"image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
"code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER,
},
}
@@ -1328,9 +1335,13 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
- If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource.
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
- **If a link to an image, audio, or any file is provided in markdown format, ALWAYS regurgitate explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
####################################
# Vector Database
####################################
@@ -1741,6 +1752,11 @@ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
)
EXA_API_KEY = PersistentConfig(
"EXA_API_KEY",
"rag.web.search.exa_api_key",
os.getenv("EXA_API_KEY", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",

View File

@@ -250,7 +250,7 @@ async def generate_function_chat_completion(
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)

View File

@@ -179,6 +179,7 @@ from open_webui.config import (
BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY,
BRAVE_SEARCH_API_KEY,
EXA_API_KEY,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
@@ -525,6 +526,7 @@ app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
@@ -863,6 +865,7 @@ async def chat_completion(
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
@@ -880,12 +883,24 @@ async def chat_completion(
"files": form_data.get("files", None),
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),
"model": model_info,
**(
{"function_calling": "native"}
if form_data.get("params", {}).get("function_calling") == "native"
or (
model_info
and model_info.params.model_dump().get("function_calling")
== "native"
)
else {}
),
}
form_data["metadata"] = metadata
form_data, events = await process_chat_payload(
form_data, metadata, events = await process_chat_payload(
request, form_data, metadata, user, model
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -894,6 +909,7 @@ async def chat_completion(
try:
response = await chat_completion_handler(request, form_data, user)
return await process_chat_response(
request, response, form_data, user, events, metadata, tasks
)

View File

@@ -15,8 +15,13 @@ from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.models.users import UserModel
from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -61,9 +66,7 @@ class VectorSearchRetriever(BaseRetriever):
def query_doc(
collection_name: str,
query_embedding: list[float],
k: int,
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
try:
result = VECTOR_DB_CLIENT.search(
@@ -259,26 +262,31 @@ def get_embedding_function(
embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
return lambda query, user=None: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
func = lambda query: generate_embeddings(
func = lambda query, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
url=url,
key=key,
user=user,
)
def generate_multiple(query, func):
def generate_multiple(query, user, func):
if isinstance(query, list):
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(func(query[i : i + embedding_batch_size]))
embeddings.extend(
func(query[i : i + embedding_batch_size], user=user)
)
return embeddings
else:
return func(query)
return func(query, user)
return lambda query: generate_multiple(query, func)
return lambda query, user=None: generate_multiple(query, user, func)
else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
def get_sources_from_files(
@@ -423,7 +431,11 @@ def get_model_path(model: str, update_model: bool = False):
def generate_openai_batch_embeddings(
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
r = requests.post(
@@ -431,6 +443,16 @@ def generate_openai_batch_embeddings(
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json={"input": texts, "model": model},
)
@@ -446,7 +468,7 @@ def generate_openai_batch_embeddings(
def generate_ollama_batch_embeddings(
model: str, texts: list[str], url: str, key: str = ""
model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
) -> Optional[list[list[float]]]:
try:
r = requests.post(
@@ -454,6 +476,16 @@ def generate_ollama_batch_embeddings(
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
json={"input": texts, "model": model},
)
@@ -472,22 +504,29 @@ def generate_ollama_batch_embeddings(
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
user = kwargs.get("user")
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": text, "url": url, "key": key}
**{"model": model, "texts": text, "url": url, "key": key, "user": user}
)
else:
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": [text], "url": url, "key": key}
**{
"model": model,
"texts": [text],
"url": url,
"key": key,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, url, key)
embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
else:
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
return embeddings[0] if isinstance(text, str) else embeddings

View File

@@ -0,0 +1,76 @@
import logging
from dataclasses import dataclass
from typing import Optional
import requests
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult
# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
# Base URL of the Exa API; the "/search" path is appended to it when querying.
EXA_API_BASE = "https://api.exa.ai"
@dataclass
class ExaResult:
    """Minimal record of one Exa search hit: its URL, title and text."""

    url: str
    title: str
    text: str
def search_exa(
    api_key: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using the Exa Search API and return the results as SearchResult objects.

    Args:
        api_key (str): An Exa Search API key
        query (str): The query to search for
        count (int): Number of results to return; a falsy value falls back to 5
        filter_list (Optional[list[str]]): List of domains to filter results by

    Returns:
        list[SearchResult]: One entry per hit. An empty list is returned when
        the request fails, times out, or the response cannot be processed;
        the error is logged rather than raised.
    """
    log.info(f"Searching with Exa for query: {query}")

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = {
        "query": query,
        "numResults": count or 5,
        "includeDomains": filter_list,
        "contents": {"text": True, "highlights": True},
        "type": "auto",  # Use the auto search type (keyword or neural)
    }

    try:
        # Without a timeout, requests waits indefinitely on a stalled
        # connection, which would hang the calling web-search request.
        response = requests.post(
            f"{EXA_API_BASE}/search", headers=headers, json=payload, timeout=30
        )
        response.raise_for_status()
        data = response.json()

        # "title" and "text" are read leniently so that one sparse hit does
        # not discard the whole result set via the broad handler below.
        hits = [
            ExaResult(
                url=item["url"],
                title=item.get("title"),
                text=item.get("text"),
            )
            for item in data.get("results", [])
        ]

        log.info(f"Found {len(hits)} results")
        return [
            SearchResult(
                link=hit.url,
                title=hit.title,
                snippet=hit.text,
            )
            for hit in hits
        ]
    except Exception as e:
        # Best-effort search: report the failure and let the caller continue
        # with no results.
        log.error(f"Error searching Exa: {e}")
        return []

View File

@@ -48,6 +48,7 @@ def validate_url(url: Union[str, Sequence[str]]):
else:
return False
def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
valid_urls = []
for u in url:
@@ -57,6 +58,7 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
except ValueError:
continue
return valid_urls
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)

View File

@@ -71,7 +71,7 @@ def upload_file(
)
try:
process_file(request, ProcessFileForm(file_id=id))
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@@ -193,7 +193,9 @@ async def update_file_data_content_by_id(
if file and (file.user_id == user.id or user.role == "admin"):
try:
process_file(
request, ProcessFileForm(file_id=id, content=form_data.content)
request,
ProcessFileForm(file_id=id, content=form_data.content),
user=user,
)
file = Files.get_file_by_id(id=id)
except Exception as e:

View File

@@ -289,7 +289,9 @@ def add_file_to_knowledge_by_id(
# Add content to the vector database
try:
process_file(
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user,
)
except Exception as e:
log.debug(e)
@@ -372,7 +374,9 @@ def update_file_from_knowledge_by_id(
# Add content to the vector database
try:
process_file(
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user,
)
except Exception as e:
raise HTTPException(

View File

@@ -57,7 +57,7 @@ async def add_memory(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
"metadata": {"created_at": memory.created_at},
}
],
@@ -82,7 +82,7 @@ async def query_memory(
):
results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
limit=form_data.k,
)
@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
@@ -160,7 +160,9 @@ async def update_memory_by_id(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user
),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,

View File

@@ -939,6 +939,7 @@ async def generate_completion(
class ChatMessage(BaseModel):
role: str
content: str
tool_calls: Optional[list[dict]] = None
images: Optional[list[str]] = None
@@ -950,6 +951,7 @@ class GenerateChatCompletionForm(BaseModel):
template: Optional[str] = None
stream: Optional[bool] = True
keep_alive: Optional[Union[int, str]] = None
tools: Optional[list[dict]] = None
async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
@@ -1005,7 +1007,7 @@ async def generate_chat_completion(
payload["options"] = apply_model_params_to_body_ollama(
params, payload["options"]
)
payload = apply_model_system_prompt_to_body(params, payload, metadata)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
@@ -1158,6 +1160,8 @@ async def generate_openai_chat_completion(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
metadata = form_data.pop("metadata", None)
try:
completion_form = OpenAIChatCompletionForm(**form_data)
except Exception as e:
@@ -1184,7 +1188,7 @@ async def generate_openai_chat_completion(
if params:
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
# Check if user has access to the model
if user.role == "user":

View File

@@ -566,7 +566,7 @@ async def generate_chat_completion(
params = model_info.params.model_dump()
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, metadata)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":

View File

@@ -55,6 +55,7 @@ from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack
from open_webui.retrieval.web.tavily import search_tavily
from open_webui.retrieval.web.bing import search_bing
from open_webui.retrieval.web.exa import search_exa
from open_webui.retrieval.utils import (
@@ -388,6 +389,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"jina_api_key": request.app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
@@ -436,6 +438,7 @@ class WebSearchConfig(BaseModel):
jina_api_key: Optional[str] = None
bing_search_v7_endpoint: Optional[str] = None
bing_search_v7_subscription_key: Optional[str] = None
exa_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
@@ -542,6 +545,8 @@ async def update_rag_config(
form_data.web.search.bing_search_v7_subscription_key
)
request.app.state.config.EXA_API_KEY = form_data.web.search.exa_api_key
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
form_data.web.search.result_count
)
@@ -591,6 +596,7 @@ async def update_rag_config(
"jina_api_key": request.app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
@@ -660,6 +666,7 @@ def save_docs_to_vector_db(
overwrite: bool = False,
split: bool = True,
add: bool = False,
user=None,
) -> bool:
def _get_docs_info(docs: list[Document]) -> str:
docs_info = set()
@@ -775,7 +782,7 @@ def save_docs_to_vector_db(
)
embeddings = embedding_function(
list(map(lambda x: x.replace("\n", " "), texts))
list(map(lambda x: x.replace("\n", " "), texts)), user=user
)
items = [
@@ -933,6 +940,7 @@ def process_file(
"hash": hash,
},
add=(True if form_data.collection_name else False),
user=user,
)
if result:
@@ -990,7 +998,7 @@ def process_text(
text_content = form_data.content
log.debug(f"text_content: {text_content}")
result = save_docs_to_vector_db(request, docs, collection_name)
result = save_docs_to_vector_db(request, docs, collection_name, user=user)
if result:
return {
"status": True,
@@ -1023,7 +1031,9 @@ def process_youtube_video(
content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {content}")
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
save_docs_to_vector_db(
request, docs, collection_name, overwrite=True, user=user
)
return {
"status": True,
@@ -1064,7 +1074,9 @@ def process_web(
content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {content}")
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
save_docs_to_vector_db(
request, docs, collection_name, overwrite=True, user=user
)
return {
"status": True,
@@ -1099,6 +1111,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
- SERPER_API_KEY
- SERPLY_API_KEY
- TAVILY_API_KEY
- EXA_API_KEY
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
Args:
query (str): The query to search for
@@ -1233,6 +1246,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "exa":
return search_exa(
request.app.state.config.EXA_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No search engine API key found in environment variables")
@@ -1273,7 +1293,9 @@ async def process_web_search(
)
docs = [doc async for doc in loader.alazy_load()]
# docs = loader.load()
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
save_docs_to_vector_db(
request, docs, collection_name, overwrite=True, user=user
)
return {
"status": True,
@@ -1307,7 +1329,9 @@ def query_doc_handler(
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=request.app.state.EMBEDDING_FUNCTION,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=(
@@ -1315,12 +1339,16 @@ def query_doc_handler(
if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD
),
user=user,
)
else:
return query_doc(
collection_name=form_data.collection_name,
query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
query_embedding=request.app.state.EMBEDDING_FUNCTION(
form_data.query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
user=user,
)
except Exception as e:
log.exception(e)
@@ -1349,7 +1377,9 @@ def query_collection_handler(
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=request.app.state.EMBEDDING_FUNCTION,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=(
@@ -1362,7 +1392,9 @@ def query_collection_handler(
return query_collection(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=request.app.state.EMBEDDING_FUNCTION,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
)
@@ -1510,6 +1542,7 @@ def process_files_batch(
docs=all_docs,
collection_name=collection_name,
add=True,
user=user,
)
# Update all files with collection name

View File

@@ -79,6 +79,7 @@ class ChatPermissions(BaseModel):
class FeaturesPermissions(BaseModel):
web_search: bool = True
image_generation: bool = True
code_interpreter: bool = True
class UserPermissions(BaseModel):

View File

@@ -1,6 +1,8 @@
import time
import logging
import sys
import os
import base64
import asyncio
from aiocache import cached
@@ -10,6 +12,7 @@ import json
import html
import inspect
import re
import ast
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
@@ -55,6 +58,7 @@ from open_webui.utils.task import (
tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
deep_update,
get_message_list,
add_or_update_system_message,
add_or_update_user_message,
@@ -69,6 +73,7 @@ from open_webui.utils.plugin import load_function_module_by_id
from open_webui.tasks import create_task
from open_webui.config import (
CACHE_DIR,
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
DEFAULT_CODE_INTERPRETER_PROMPT,
)
@@ -180,7 +185,7 @@ async def chat_completion_filter_functions_handler(request, body, model, extra_p
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, extra_params: dict
request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@@ -215,35 +220,15 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
tool_ids = metadata.get("tool_ids", None)
log.debug(f"{tool_ids=}")
if not tool_ids:
return body, {}
skip_files = False
sources = []
task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
tools = get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": body["messages"],
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
skip_files = False
sources = []
specs = [tool["spec"] for tool in tools.values()]
tools_specs = json.dumps(specs)
@@ -278,6 +263,8 @@ async def chat_completion_tools_handler(
result = json.loads(content)
async def tool_call_handler(tool_call):
nonlocal skip_files
log.debug(f"{tool_call=}")
tool_function_name = tool_call.get("name", None)
@@ -418,7 +405,7 @@ async def chat_web_search_handler(
},
}
)
return
return form_data
searchQuery = queries[0]
@@ -641,7 +628,9 @@ async def chat_completion_files_handler(
lambda: get_sources_from_files(
files=files,
queries=queries,
embedding_function=request.app.state.EMBEDDING_FUNCTION,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
query, user=user
),
k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=request.app.state.config.RELEVANCE_THRESHOLD,
@@ -693,6 +682,7 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, metadata, user, model):
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@@ -715,6 +705,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
# Initialize events to store additional event to be sent to the client
# Initialize contexts and citation
models = request.app.state.MODELS
task_model_id = get_task_model_id(
form_data["model"],
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
events = []
sources = []
@@ -799,13 +795,41 @@ async def process_chat_payload(request, form_data, metadata, user, model):
}
form_data["metadata"] = metadata
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, extra_params
tool_ids = metadata.get("tool_ids", None)
log.debug(f"{tool_ids=}")
if tool_ids:
# If tool_ids field is present, then get the tools
tools = get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": form_data["messages"],
"__files__": metadata.get("files", []),
},
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
log.info(f"{tools=}")
if metadata.get("function_calling") == "native":
# If the function calling is native, attach the tools to the metadata and pass their specs along in the request
metadata["tools"] = tools
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
try:
form_data, flags = await chat_completion_files_handler(request, form_data, user)
@@ -821,11 +845,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
metadata = source.get("metadata")
doc_metadata = source.get("metadata")
doc_source_id = None
if metadata:
doc_source_id = metadata[doc_idx].get("source", source_id)
if doc_metadata:
doc_source_id = doc_metadata[doc_idx].get("source", source_id)
if source_id:
context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
@@ -882,7 +906,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
}
)
return form_data, events
return form_data, metadata, events
async def process_chat_response(
@@ -1100,6 +1124,40 @@ async def process_chat_response(
for block in content_blocks:
if block["type"] == "text":
content = f"{content}{block['content'].strip()}\n"
elif block["type"] == "tool_calls":
attributes = block.get("attributes", {})
block_content = block.get("content", [])
results = block.get("results", [])
if results:
result_display_content = ""
for result in results:
tool_call_id = result.get("tool_call_id", "")
tool_name = ""
for tool_call in block_content:
if tool_call.get("id", "") == tool_call_id:
tool_name = tool_call.get("function", {}).get(
"name", ""
)
break
result_display_content = f"{result_display_content}\n> {tool_name}: {result.get('content', '')}"
if not raw:
content = f'{content}\n<details type="tool_calls" done="true" content="{html.escape(json.dumps(block_content))}" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n{result_display_content}\n</details>\n'
else:
tool_calls_display_content = ""
for tool_call in block_content:
tool_calls_display_content = f"{tool_calls_display_content}\n> Executing {tool_call.get('function', {}).get('name', '')}"
if not raw:
content = f'{content}\n<details type="tool_calls" done="false" content="{html.escape(json.dumps(block_content))}">\n<summary>Tool Executing...</summary>\n{tool_calls_display_content}\n</details>\n'
elif block["type"] == "reasoning":
reasoning_display_content = "\n".join(
(f"> {line}" if not line.startswith(">") else line)
@@ -1108,16 +1166,16 @@ async def process_chat_response(
reasoning_duration = block.get("duration", None)
if reasoning_duration:
if reasoning_duration is not None:
if raw:
content = f'{content}<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
else:
content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
else:
if raw:
content = f'{content}<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
else:
content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
elif block["type"] == "code_interpreter":
attributes = block.get("attributes", {})
@@ -1128,20 +1186,20 @@ async def process_chat_response(
output = html.escape(json.dumps(output))
if raw:
content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
else:
content = f'{content}<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
else:
if raw:
content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
else:
content = f'{content}<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
else:
block_content = str(block["content"]).strip()
content = f"{content}{block['type']}: {block_content}\n"
return content
return content.strip()
def tag_content_handler(content_type, tags, content, content_blocks):
end_flag = False
@@ -1149,6 +1207,8 @@ async def process_chat_response(
def extract_attributes(tag_content):
"""Extract attributes from a tag if they exist."""
attributes = {}
if not tag_content: # Ensure tag_content is not None
return attributes
# Match attributes in the format: key="value" (ignores single quotes for simplicity)
matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content)
for key, value in matches:
@@ -1158,17 +1218,35 @@ async def process_chat_response(
if content_blocks[-1]["type"] == "text":
for tag in tags:
# Match start tag e.g., <tag> or <tag attr="value">
start_tag_pattern = rf"<{tag}(.*?)>"
start_tag_pattern = rf"<{tag}(\s.*?)?>"
match = re.search(start_tag_pattern, content)
if match:
# Extract attributes in the tag (if present)
attributes = extract_attributes(match.group(1))
attr_content = (
match.group(1) if match.group(1) else ""
) # Ensure it's not None
attributes = extract_attributes(
attr_content
) # Extract attributes safely
# Capture everything before and after the matched tag
before_tag = content[
: match.start()
] # Content before opening tag
after_tag = content[
match.end() :
] # Content after opening tag
# Remove the start tag from the currently handling text block
content_blocks[-1]["content"] = content_blocks[-1][
"content"
].replace(match.group(0), "")
if before_tag:
content_blocks[-1]["content"] = before_tag
if not content_blocks[-1]["content"]:
content_blocks.pop()
# Append the new block
content_blocks.append(
{
@@ -1179,52 +1257,100 @@ async def process_chat_response(
"started_at": time.time(),
}
)
if after_tag:
content_blocks[-1]["content"] = after_tag
break
elif content_blocks[-1]["type"] == content_type:
tag = content_blocks[-1]["tag"]
# Match end tag e.g., </tag>
end_tag_pattern = rf"</{tag}>"
# Check if the content has the end tag
if re.search(end_tag_pattern, content):
end_flag = True
block_content = content_blocks[-1]["content"]
# Strip start and end tags from the content
start_tag_pattern = rf"<{tag}(.*?)>"
block_content = re.sub(
start_tag_pattern, "", block_content
).strip()
block_content = re.sub(
end_tag_pattern, "", block_content
).strip()
end_tag_regex = re.compile(end_tag_pattern, re.DOTALL)
split_content = end_tag_regex.split(block_content, maxsplit=1)
# Content inside the tag
block_content = (
split_content[0].strip() if split_content else ""
)
# Leftover content (everything after `</tag>`)
leftover_content = (
split_content[1].strip() if len(split_content) > 1 else ""
)
if block_content:
end_flag = True
content_blocks[-1]["content"] = block_content
content_blocks[-1]["ended_at"] = time.time()
content_blocks[-1]["duration"] = int(
content_blocks[-1]["ended_at"]
- content_blocks[-1]["started_at"]
)
# Reset the content_blocks by appending a new text block
content_blocks.append(
{
"type": "text",
"content": "",
}
)
# Clean processed content
content = re.sub(
rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
"",
content,
flags=re.DOTALL,
)
if content_type != "code_interpreter":
if leftover_content:
content_blocks.append(
{
"type": "text",
"content": leftover_content,
}
)
else:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
else:
# Remove the block if content is empty
content_blocks.pop()
if leftover_content:
content_blocks.append(
{
"type": "text",
"content": leftover_content,
}
)
else:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
# Clean processed content
content = re.sub(
rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
"",
content,
flags=re.DOTALL,
)
return content, content_blocks, end_flag
message = Chats.get_message_by_id_and_message_id(
metadata["chat_id"], metadata["message_id"]
)
tool_calls = []
content = message.get("content", "") if message else ""
content_blocks = [
{
@@ -1235,9 +1361,18 @@ async def process_chat_response(
# We might want to disable this by default
DETECT_REASONING = True
DETECT_CODE_INTERPRETER = True
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
"code_interpreter", False
)
reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
reasoning_tags = [
"think",
"thinking",
"reason",
"reasoning",
"thought",
"Thought",
]
code_interpreter_tags = ["code_interpreter"]
try:
@@ -1262,6 +1397,8 @@ async def process_chat_response(
nonlocal content
nonlocal content_blocks
response_tool_calls = []
async for line in response.body_iterator:
line = line.decode("utf-8") if isinstance(line, bytes) else line
data = line
@@ -1294,10 +1431,54 @@ async def process_chat_response(
if not choices:
continue
value = choices[0].get("delta", {}).get("content")
delta = choices[0].get("delta", {})
delta_tool_calls = delta.get("tool_calls", None)
if delta_tool_calls:
for delta_tool_call in delta_tool_calls:
tool_call_index = delta_tool_call.get("index")
if tool_call_index is not None:
if (
len(response_tool_calls)
<= tool_call_index
):
response_tool_calls.append(
delta_tool_call
)
else:
delta_name = delta_tool_call.get(
"function", {}
).get("name")
delta_arguments = delta_tool_call.get(
"function", {}
).get("arguments")
if delta_name:
response_tool_calls[
tool_call_index
]["function"]["name"] += delta_name
if delta_arguments:
response_tool_calls[
tool_call_index
]["function"][
"arguments"
] += delta_arguments
value = delta.get("content")
if value:
content = f"{content}{value}"
if not content_blocks:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
content_blocks[-1]["content"] = (
content_blocks[-1]["content"] + value
)
@@ -1357,14 +1538,46 @@ async def process_chat_response(
log.debug("Error: ", e)
continue
# Clean up the last text block
if content_blocks[-1]["type"] == "text":
content_blocks[-1]["content"] = content_blocks[-1][
"content"
].strip()
if content_blocks:
# Clean up the last text block
if content_blocks[-1]["type"] == "text":
content_blocks[-1]["content"] = content_blocks[-1][
"content"
].strip()
if not content_blocks[-1]["content"]:
content_blocks.pop()
if not content_blocks[-1]["content"]:
content_blocks.pop()
if not content_blocks:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
if response_tool_calls:
tool_calls.append(response_tool_calls)
if response.background:
await response.background()
await stream_body_handler(response)
MAX_TOOL_CALL_RETRIES = 5
tool_call_retries = 0
while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
tool_call_retries += 1
response_tool_calls = tool_calls.pop(0)
content_blocks.append(
{
"type": "tool_calls",
"content": response_tool_calls,
}
)
await event_emitter(
{
@@ -1375,37 +1588,54 @@ async def process_chat_response(
}
)
if response.background:
await response.background()
tools = metadata.get("tools", {})
await stream_body_handler(response)
results = []
for tool_call in response_tool_calls:
print("\n\n" + str(tool_call) + "\n\n")
tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("function", {}).get("name", "")
MAX_RETRIES = 5
retries = 0
while (
content_blocks[-1]["type"] == "code_interpreter"
and retries < MAX_RETRIES
):
retries += 1
log.debug(f"Attempt count: {retries}")
output = ""
try:
if content_blocks[-1]["attributes"].get("type") == "code":
output = await event_caller(
{
"type": "execute:python",
"data": {
"id": str(uuid4()),
"code": content_blocks[-1]["content"],
},
}
tool_function_params = {}
try:
# json.loads cannot be used because some models do not produce valid JSON
tool_function_params = ast.literal_eval(
tool_call.get("function", {}).get("arguments", "{}")
)
except Exception as e:
output = str(e)
except Exception as e:
log.debug(e)
tool_result = None
if tool_name in tools:
tool = tools[tool_name]
spec = tool.get("spec", {})
try:
required_params = spec.get("parameters", {}).get(
"required", []
)
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
}
tool_result = await tool_function(
**tool_function_params
)
except Exception as e:
tool_result = str(e)
results.append(
{
"tool_call_id": tool_call_id,
"content": tool_result,
}
)
content_blocks[-1]["results"] = results
content_blocks[-1]["output"] = output
content_blocks.append(
{
"type": "text",
@@ -1435,7 +1665,16 @@ async def process_chat_response(
"content": serialize_content_blocks(
content_blocks, raw=True
),
"tool_calls": response_tool_calls,
},
*[
{
"role": "tool",
"tool_call_id": result["tool_call_id"],
"content": result["content"],
}
for result in results
],
],
},
user,
@@ -1449,6 +1688,110 @@ async def process_chat_response(
log.debug(e)
break
if DETECT_CODE_INTERPRETER:
MAX_RETRIES = 5
retries = 0
while (
content_blocks[-1]["type"] == "code_interpreter"
and retries < MAX_RETRIES
):
retries += 1
log.debug(f"Attempt count: {retries}")
output = ""
try:
if content_blocks[-1]["attributes"].get("type") == "code":
output = await event_caller(
{
"type": "execute:python",
"data": {
"id": str(uuid4()),
"code": content_blocks[-1]["content"],
},
}
)
if isinstance(output, dict):
stdout = output.get("stdout", "")
if stdout:
stdoutLines = stdout.split("\n")
for idx, line in enumerate(stdoutLines):
if "data:image/png;base64" in line:
id = str(uuid4())
# ensure the path exists
os.makedirs(
os.path.join(CACHE_DIR, "images"),
exist_ok=True,
)
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode(
line.split(",")[1]
)
)
stdoutLines[idx] = (
f"![Output Image {idx}](/cache/images/{id}.png)"
)
output["stdout"] = "\n".join(stdoutLines)
except Exception as e:
output = str(e)
content_blocks[-1]["output"] = output
content_blocks.append(
{
"type": "text",
"content": "",
}
)
await event_emitter(
{
"type": "chat:completion",
"data": {
"content": serialize_content_blocks(content_blocks),
},
}
)
try:
res = await generate_chat_completion(
request,
{
"model": model_id,
"stream": True,
"messages": [
*form_data["messages"],
{
"role": "assistant",
"content": serialize_content_blocks(
content_blocks, raw=True
),
},
],
},
user,
)
if isinstance(res, StreamingResponse):
await stream_body_handler(res)
else:
break
except Exception as e:
log.debug(e)
break
title = Chats.get_chat_title_by_id(metadata["chat_id"])
data = {
"done": True,

View File

@@ -7,6 +7,18 @@ from pathlib import Path
from typing import Callable, Optional
import collections.abc
def deep_update(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place.

    For every key in ``u``:

    * if the incoming value is a mapping, it is merged recursively into the
      value already stored under that key in ``d``;
    * otherwise the incoming value overwrites whatever ``d`` held.

    When the incoming value is a mapping but ``d`` holds a non-mapping value
    (or nothing) under that key, the old value is discarded and the merge
    starts from an empty dict. Without this, the recursive call would try
    item assignment on e.g. an ``int`` or ``None`` and raise ``TypeError``.

    Args:
        d: The dict to update (mutated in place).
        u: The mapping whose entries are merged into ``d``.

    Returns:
        The same ``d`` object, after merging.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            existing = d.get(k)
            # Only recurse into something we can actually assign keys on;
            # scalars/None/strings are replaced by a fresh dict.
            if not isinstance(existing, collections.abc.MutableMapping):
                existing = {}
            d[k] = deep_update(existing, v)
        else:
            d[k] = v
    return d
def get_message_list(messages, message_id):
"""
Reconstructs a list of messages in order up to the specified message_id.
@@ -179,15 +191,25 @@ def openai_chat_message_template(model: str):
def openai_chat_chunk_message_template(
model: str, message: Optional[str] = None, usage: Optional[dict] = None
model: str,
content: Optional[str] = None,
tool_calls: Optional[list[dict]] = None,
usage: Optional[dict] = None,
) -> dict:
template = openai_chat_message_template(model)
template["object"] = "chat.completion.chunk"
if message:
template["choices"][0]["delta"] = {"content": message}
else:
template["choices"][0]["index"] = 0
template["choices"][0]["delta"] = {}
if content:
template["choices"][0]["delta"]["content"] = content
if tool_calls:
template["choices"][0]["delta"]["tool_calls"] = tool_calls
if not content and not tool_calls:
template["choices"][0]["finish_reason"] = "stop"
template["choices"][0]["delta"] = {}
if usage:
template["usage"] = usage

View File

@@ -1,4 +1,4 @@
from open_webui.utils.task import prompt_variables_template
from open_webui.utils.task import prompt_template, prompt_variables_template
from open_webui.utils.misc import (
add_or_update_system_message,
)
@@ -8,16 +8,28 @@ from typing import Callable, Optional
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(
params: dict, form_data: dict, metadata: Optional[dict] = None
params: dict, form_data: dict, metadata: Optional[dict] = None, user=None
) -> dict:
system = params.get("system", None)
if not system:
return form_data
# Legacy (API Usage)
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
# Metadata (WebUI Usage)
if metadata:
print("apply_model_system_prompt_to_body: metadata", metadata)
variables = metadata.get("variables", {})
system = prompt_variables_template(system, variables)
if variables:
system = prompt_variables_template(system, variables)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
@@ -154,6 +166,9 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
)
ollama_payload["stream"] = openai_payload.get("stream", False)
if "tools" in openai_payload:
ollama_payload["tools"] = openai_payload["tools"]
if "format" in openai_payload:
ollama_payload["format"] = openai_payload["format"]

View File

@@ -1,4 +1,5 @@
import json
from uuid import uuid4
from open_webui.utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
@@ -60,6 +61,23 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
model = data.get("model", "ollama")
message_content = data.get("message", {}).get("content", "")
tool_calls = data.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
if tool_calls:
openai_tool_calls = []
for tool_call in tool_calls:
openai_tool_call = {
"index": tool_call.get("index", 0),
"id": tool_call.get("id", f"call_{str(uuid4())}"),
"type": "function",
"function": {
"name": tool_call.get("function", {}).get("name", ""),
"arguments": f"{tool_call.get('function', {}).get('arguments', {})}",
},
}
openai_tool_calls.append(openai_tool_call)
done = data.get("done", False)
usage = None
@@ -105,7 +123,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
}
data = openai_chat_chunk_message_template(
model, message_content if not done else None, usage
model, message_content if not done else None, openai_tool_calls, usage
)
line = f"data: {json.dumps(data)}\n\n"

View File

@@ -61,6 +61,12 @@ def get_tools(
)
for spec in tools.specs:
# TODO: Fix hack for OpenAI API
# Some times breaks OpenAI but others don't. Leaving the comment
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
# Remove internal parameters
spec["parameters"]["properties"] = {
key: val
@@ -73,6 +79,13 @@ def get_tools(
# convert to function that takes only model params and inserts custom params
original_func = getattr(module, function_name)
callable = apply_extra_params_to_tool_function(original_func, extra_params)
if callable.__doc__ and callable.__doc__.strip() != "":
s = re.split(":(param|return)", callable.__doc__, 1)
spec["description"] = s[0]
else:
spec["description"] = function_name
# TODO: This needs to be a pydantic model
tool_dict = {
"toolkit_id": tool_id,