Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-16 11:57:51 +01:00)

Merge remote-tracking branch 'upstream/dev' into playwright
@@ -927,6 +927,12 @@ USER_PERMISSIONS_FEATURES_IMAGE_GENERATION = (
    == "true"
)

+USER_PERMISSIONS_FEATURES_CODE_INTERPRETER = (
+    os.environ.get("USER_PERMISSIONS_FEATURES_CODE_INTERPRETER", "True").lower()
+    == "true"
+)
+
+
DEFAULT_USER_PERMISSIONS = {
    "workspace": {
        "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
@@ -944,6 +950,7 @@ DEFAULT_USER_PERMISSIONS = {
    "features": {
        "web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
        "image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
+       "code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER,
    },
}

@@ -1328,9 +1335,13 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
- If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource.
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
- **If a link to an image, audio, or any file is provided in markdown format, ALWAYS display it explicitly as part of the response to ensure the user can access it easily; do NOT change the link.**

Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""


####################################
# Vector Database
####################################
@@ -1741,6 +1752,11 @@ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
    os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
)

+EXA_API_KEY = PersistentConfig(
+    "EXA_API_KEY",
+    "rag.web.search.exa_api_key",
+    os.getenv("EXA_API_KEY", ""),
+)

RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
@@ -250,7 +250,7 @@ async def generate_function_chat_completion(

    params = model_info.params.model_dump()
    form_data = apply_model_params_to_body_openai(params, form_data)
-   form_data = apply_model_system_prompt_to_body(params, form_data, user)
+   form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module_by_id(request, pipe_id)
@@ -179,6 +179,7 @@ from open_webui.config import (
    BING_SEARCH_V7_ENDPOINT,
    BING_SEARCH_V7_SUBSCRIPTION_KEY,
    BRAVE_SEARCH_API_KEY,
+   EXA_API_KEY,
    KAGI_SEARCH_API_KEY,
    MOJEEK_SEARCH_API_KEY,
    GOOGLE_PSE_API_KEY,
@@ -525,6 +526,7 @@ app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
+app.state.config.EXA_API_KEY = EXA_API_KEY

app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
@@ -863,6 +865,7 @@ async def chat_completion(
    if model_id not in request.app.state.MODELS:
        raise Exception("Model not found")
    model = request.app.state.MODELS[model_id]
+   model_info = Models.get_model_by_id(model_id)

    # Check if user has access to the model
    if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
@@ -880,12 +883,24 @@ async def chat_completion(
        "files": form_data.get("files", None),
        "features": form_data.get("features", None),
        "variables": form_data.get("variables", None),
+       "model": model_info,
+       **(
+           {"function_calling": "native"}
+           if form_data.get("params", {}).get("function_calling") == "native"
+           or (
+               model_info
+               and model_info.params.model_dump().get("function_calling")
+               == "native"
+           )
+           else {}
+       ),
    }
    form_data["metadata"] = metadata

-   form_data, events = await process_chat_payload(
+   form_data, metadata, events = await process_chat_payload(
        request, form_data, metadata, user, model
    )

except Exception as e:
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
@@ -894,6 +909,7 @@ async def chat_completion(

    try:
        response = await chat_completion_handler(request, form_data, user)

        return await process_chat_response(
            request, response, form_data, user, events, metadata, tasks
        )
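Review note: the metadata block above decides up front whether native function calling is in play, from either the request's params or the stored model params. A standalone restatement of that decision (illustrative helper, not part of the diff):

```python
def native_function_calling_requested(form_data: dict, model_info) -> bool:
    """True when the request or the stored model params opt into native tool calls."""
    if form_data.get("params", {}).get("function_calling") == "native":
        return True
    return bool(
        model_info
        and model_info.params.model_dump().get("function_calling") == "native"
    )
```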
@@ -15,8 +15,13 @@ from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
+from open_webui.models.users import UserModel

-from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
+from open_webui.env import (
+    SRC_LOG_LEVELS,
+    OFFLINE_MODE,
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -61,9 +66,7 @@ class VectorSearchRetriever(BaseRetriever):


def query_doc(
-   collection_name: str,
-   query_embedding: list[float],
-   k: int,
+   collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    try:
        result = VECTOR_DB_CLIENT.search(
@@ -259,26 +262,31 @@ def get_embedding_function(
    embedding_batch_size,
):
    if embedding_engine == "":
-       return lambda query: embedding_function.encode(query).tolist()
+       return lambda query, user=None: embedding_function.encode(query).tolist()
    elif embedding_engine in ["ollama", "openai"]:
-       func = lambda query: generate_embeddings(
+       func = lambda query, user=None: generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=query,
            url=url,
            key=key,
+           user=user,
        )

-       def generate_multiple(query, func):
+       def generate_multiple(query, user, func):
            if isinstance(query, list):
                embeddings = []
                for i in range(0, len(query), embedding_batch_size):
-                   embeddings.extend(func(query[i : i + embedding_batch_size]))
+                   embeddings.extend(
+                       func(query[i : i + embedding_batch_size], user=user)
+                   )
                return embeddings
            else:
-               return func(query)
+               return func(query, user)

-       return lambda query: generate_multiple(query, func)
+       return lambda query, user=None: generate_multiple(query, user, func)
    else:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")


def get_sources_from_files(
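For reviewers, a sketch of how the rebuilt factory is consumed after this change: the returned callable now accepts an optional `user`, for both single strings and batched lists. Argument values and the positional order of the earlier parameters are assumptions inferred from the hunk:

```python
embedding_fn = get_embedding_function(
    "openai",                     # embedding_engine
    "text-embedding-3-small",     # embedding_model
    None,                         # local embedding_function (unused for "openai")
    "https://api.openai.com/v1",  # url
    "sk-example",                 # key (placeholder)
    32,                           # embedding_batch_size
)

one_vector = embedding_fn("hello world", user=current_user)
many_vectors = embedding_fn(["chunk one", "chunk two"], user=current_user)
legacy_call = embedding_fn("still works")  # user defaults to None
```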
@@ -423,7 +431,11 @@ def get_model_path(model: str, update_model: bool = False):


def generate_openai_batch_embeddings(
-   model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+   model: str,
+   texts: list[str],
+   url: str = "https://api.openai.com/v1",
+   key: str = "",
+   user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        r = requests.post(
@@ -431,6 +443,16 @@ def generate_openai_batch_embeddings(
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
+               **(
+                   {
+                       "X-OpenWebUI-User-Name": user.name,
+                       "X-OpenWebUI-User-Id": user.id,
+                       "X-OpenWebUI-User-Email": user.email,
+                       "X-OpenWebUI-User-Role": user.role,
+                   }
+                   if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                   else {}
+               ),
            },
            json={"input": texts, "model": model},
        )
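Both batch-embedding helpers now build this same conditional header block inline. The pattern in isolation (a hypothetical helper; the diff inlines it at each call site):

```python
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS


def user_info_headers(user) -> dict:
    """Optional X-OpenWebUI-* headers identifying the requesting user."""
    if not (ENABLE_FORWARD_USER_INFO_HEADERS and user):
        return {}
    return {
        "X-OpenWebUI-User-Name": user.name,
        "X-OpenWebUI-User-Id": user.id,
        "X-OpenWebUI-User-Email": user.email,
        "X-OpenWebUI-User-Role": user.role,
    }
```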
@@ -446,7 +468,7 @@ def generate_openai_batch_embeddings(


def generate_ollama_batch_embeddings(
-   model: str, texts: list[str], url: str, key: str = ""
+   model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
) -> Optional[list[list[float]]]:
    try:
        r = requests.post(
@@ -454,6 +476,16 @@ def generate_ollama_batch_embeddings(
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
+               **(
+                   {
+                       "X-OpenWebUI-User-Name": user.name,
+                       "X-OpenWebUI-User-Id": user.id,
+                       "X-OpenWebUI-User-Email": user.email,
+                       "X-OpenWebUI-User-Role": user.role,
+                   }
+                   if ENABLE_FORWARD_USER_INFO_HEADERS
+                   else {}
+               ),
            },
            json={"input": texts, "model": model},
        )
@@ -472,22 +504,29 @@ def generate_ollama_batch_embeddings(
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
+   user = kwargs.get("user")

    if engine == "ollama":
        if isinstance(text, list):
            embeddings = generate_ollama_batch_embeddings(
-               **{"model": model, "texts": text, "url": url, "key": key}
+               **{"model": model, "texts": text, "url": url, "key": key, "user": user}
            )
        else:
            embeddings = generate_ollama_batch_embeddings(
-               **{"model": model, "texts": [text], "url": url, "key": key}
+               **{
+                   "model": model,
+                   "texts": [text],
+                   "url": url,
+                   "key": key,
+                   "user": user,
+               }
            )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "openai":
        if isinstance(text, list):
-           embeddings = generate_openai_batch_embeddings(model, text, url, key)
+           embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
        else:
-           embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+           embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)

        return embeddings[0] if isinstance(text, str) else embeddings
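Usage keeps the same shape; only `user` rides along in kwargs so it can be forwarded as headers when enabled. Values here are illustrative:

```python
# str in -> one vector out
vector = generate_embeddings(
    engine="ollama",
    model="nomic-embed-text",
    text="a single query",
    url="http://localhost:11434",
    key="",
    user=current_user,
)

# list in -> list of vectors out
vectors = generate_embeddings(
    engine="openai",
    model="text-embedding-3-small",
    text=["doc one", "doc two"],
    url="https://api.openai.com/v1",
    key="sk-example",
    user=current_user,
)
```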
backend/open_webui/retrieval/web/exa.py — new file (76 lines)
@@ -0,0 +1,76 @@
import logging
from dataclasses import dataclass
from typing import Optional

import requests
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

EXA_API_BASE = "https://api.exa.ai"


@dataclass
class ExaResult:
    url: str
    title: str
    text: str


def search_exa(
    api_key: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using Exa Search API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): An Exa Search API key
        query (str): The query to search for
        count (int): Number of results to return
        filter_list (Optional[list[str]]): List of domains to filter results by
    """
    log.info(f"Searching with Exa for query: {query}")

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = {
        "query": query,
        "numResults": count or 5,
        "includeDomains": filter_list,
        "contents": {"text": True, "highlights": True},
        "type": "auto",  # Use the auto search type (keyword or neural)
    }

    try:
        response = requests.post(
            f"{EXA_API_BASE}/search", headers=headers, json=payload
        )
        response.raise_for_status()
        data = response.json()

        results = []
        for result in data["results"]:
            results.append(
                ExaResult(
                    url=result["url"],
                    title=result["title"],
                    text=result["text"],
                )
            )

        log.info(f"Found {len(results)} results")
        return [
            SearchResult(
                link=result.url,
                title=result.title,
                snippet=result.text,
            )
            for result in results
        ]
    except Exception as e:
        log.error(f"Error searching Exa: {e}")
        return []
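A quick usage sketch for the new module (key and domain list are hypothetical):

```python
from open_webui.retrieval.web.exa import search_exa

results = search_exa(
    api_key="exa-example-key",
    query="open source chat UI",
    count=5,
    filter_list=["github.com", "arxiv.org"],  # optional domain allowlist
)
for r in results:
    print(r.link, "-", r.title)
```

On error the function logs and returns an empty list, so callers can treat a failed engine the same as zero results.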
@@ -48,6 +48,7 @@ def validate_url(url: Union[str, Sequence[str]]):
    else:
        return False

+
def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
    valid_urls = []
    for u in url:
@@ -57,6 +58,7 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
        except ValueError:
            continue
    return valid_urls

+
def resolve_hostname(hostname):
    # Get address information
    addr_info = socket.getaddrinfo(hostname, None)
@@ -71,7 +71,7 @@ def upload_file(
    )

    try:
-       process_file(request, ProcessFileForm(file_id=id))
+       process_file(request, ProcessFileForm(file_id=id), user=user)
        file_item = Files.get_file_by_id(id=id)
    except Exception as e:
        log.exception(e)
@@ -193,7 +193,9 @@ async def update_file_data_content_by_id(
    if file and (file.user_id == user.id or user.role == "admin"):
        try:
            process_file(
-               request, ProcessFileForm(file_id=id, content=form_data.content)
+               request,
+               ProcessFileForm(file_id=id, content=form_data.content),
+               user=user,
            )
            file = Files.get_file_by_id(id=id)
        except Exception as e:
@@ -289,7 +289,9 @@ def add_file_to_knowledge_by_id(
    # Add content to the vector database
    try:
        process_file(
-           request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+           request,
+           ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+           user=user,
        )
    except Exception as e:
        log.debug(e)
@@ -372,7 +374,9 @@ def update_file_from_knowledge_by_id(
    # Add content to the vector database
    try:
        process_file(
-           request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+           request,
+           ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+           user=user,
        )
    except Exception as e:
        raise HTTPException(
@@ -57,7 +57,7 @@ async def add_memory(
        {
            "id": memory.id,
            "text": memory.content,
-           "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+           "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
            "metadata": {"created_at": memory.created_at},
        }
    ],
@@ -82,7 +82,7 @@ async def query_memory(
):
    results = VECTOR_DB_CLIENT.search(
        collection_name=f"user-memory-{user.id}",
-       vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+       vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
        limit=form_data.k,
    )

@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
        {
            "id": memory.id,
            "text": memory.content,
-           "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+           "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
            "metadata": {
                "created_at": memory.created_at,
                "updated_at": memory.updated_at,
@@ -160,7 +160,9 @@ async def update_memory_by_id(
        {
            "id": memory.id,
            "text": memory.content,
-           "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+           "vector": request.app.state.EMBEDDING_FUNCTION(
+               memory.content, user
+           ),
            "metadata": {
                "created_at": memory.created_at,
                "updated_at": memory.updated_at,
@@ -939,6 +939,7 @@ async def generate_completion(
class ChatMessage(BaseModel):
    role: str
    content: str
+   tool_calls: Optional[list[dict]] = None
    images: Optional[list[str]] = None


@@ -950,6 +951,7 @@ class GenerateChatCompletionForm(BaseModel):
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None
+   tools: Optional[list[dict]] = None


async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
@@ -1005,7 +1007,7 @@ async def generate_chat_completion(
        payload["options"] = apply_model_params_to_body_ollama(
            params, payload["options"]
        )
-       payload = apply_model_system_prompt_to_body(params, payload, metadata)
+       payload = apply_model_system_prompt_to_body(params, payload, metadata, user)

    # Check if user has access to the model
    if not bypass_filter and user.role == "user":
@@ -1158,6 +1160,8 @@ async def generate_openai_chat_completion(
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
+   metadata = form_data.pop("metadata", None)
+
    try:
        completion_form = OpenAIChatCompletionForm(**form_data)
    except Exception as e:
@@ -1184,7 +1188,7 @@ async def generate_openai_chat_completion(

    if params:
        payload = apply_model_params_to_body_openai(params, payload)
-       payload = apply_model_system_prompt_to_body(params, payload, user)
+       payload = apply_model_system_prompt_to_body(params, payload, metadata, user)

    # Check if user has access to the model
    if user.role == "user":
@@ -566,7 +566,7 @@ async def generate_chat_completion(

    params = model_info.params.model_dump()
    payload = apply_model_params_to_body_openai(params, payload)
-   payload = apply_model_system_prompt_to_body(params, payload, metadata)
+   payload = apply_model_system_prompt_to_body(params, payload, metadata, user)

    # Check if user has access to the model
    if not bypass_filter and user.role == "user":
@@ -55,6 +55,7 @@ from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack
from open_webui.retrieval.web.tavily import search_tavily
from open_webui.retrieval.web.bing import search_bing
+from open_webui.retrieval.web.exa import search_exa


from open_webui.retrieval.utils import (
@@ -388,6 +389,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
            "jina_api_key": request.app.state.config.JINA_API_KEY,
            "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
            "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
+           "exa_api_key": request.app.state.config.EXA_API_KEY,
            "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
        },
@@ -436,6 +438,7 @@ class WebSearchConfig(BaseModel):
    jina_api_key: Optional[str] = None
    bing_search_v7_endpoint: Optional[str] = None
    bing_search_v7_subscription_key: Optional[str] = None
+   exa_api_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None
@@ -542,6 +545,8 @@ async def update_rag_config(
        form_data.web.search.bing_search_v7_subscription_key
    )

+   request.app.state.config.EXA_API_KEY = form_data.web.search.exa_api_key
+
    request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
        form_data.web.search.result_count
    )
@@ -591,6 +596,7 @@ async def update_rag_config(
            "jina_api_key": request.app.state.config.JINA_API_KEY,
            "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
            "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
+           "exa_api_key": request.app.state.config.EXA_API_KEY,
            "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
        },
@@ -660,6 +666,7 @@ def save_docs_to_vector_db(
    overwrite: bool = False,
    split: bool = True,
    add: bool = False,
+   user=None,
) -> bool:
    def _get_docs_info(docs: list[Document]) -> str:
        docs_info = set()
@@ -775,7 +782,7 @@ def save_docs_to_vector_db(
        )

        embeddings = embedding_function(
-           list(map(lambda x: x.replace("\n", " "), texts))
+           list(map(lambda x: x.replace("\n", " "), texts)), user=user
        )

        items = [
@@ -933,6 +940,7 @@ def process_file(
                "hash": hash,
            },
            add=(True if form_data.collection_name else False),
+           user=user,
        )

        if result:
@@ -990,7 +998,7 @@ def process_text(
    text_content = form_data.content
    log.debug(f"text_content: {text_content}")

-   result = save_docs_to_vector_db(request, docs, collection_name)
+   result = save_docs_to_vector_db(request, docs, collection_name, user=user)
    if result:
        return {
            "status": True,
@@ -1023,7 +1031,9 @@ def process_youtube_video(
    content = " ".join([doc.page_content for doc in docs])
    log.debug(f"text_content: {content}")

-   save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+   save_docs_to_vector_db(
+       request, docs, collection_name, overwrite=True, user=user
+   )

    return {
        "status": True,
@@ -1064,7 +1074,9 @@ def process_web(
    content = " ".join([doc.page_content for doc in docs])

    log.debug(f"text_content: {content}")
-   save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+   save_docs_to_vector_db(
+       request, docs, collection_name, overwrite=True, user=user
+   )

    return {
        "status": True,
@@ -1099,6 +1111,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
        - SERPER_API_KEY
        - SERPLY_API_KEY
        - TAVILY_API_KEY
+       - EXA_API_KEY
        - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
    Args:
        query (str): The query to search for
@@ -1233,6 +1246,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
            request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
+   elif engine == "exa":
+       return search_exa(
+           request.app.state.config.EXA_API_KEY,
+           query,
+           request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+           request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+       )
    else:
        raise Exception("No search engine API key found in environment variables")
@@ -1273,7 +1293,9 @@ async def process_web_search(
    )
    docs = [doc async for doc in loader.alazy_load()]
    # docs = loader.load()
-   save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+   save_docs_to_vector_db(
+       request, docs, collection_name, overwrite=True, user=user
+   )

    return {
        "status": True,
@@ -1307,7 +1329,9 @@ def query_doc_handler(
        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
-           embedding_function=request.app.state.EMBEDDING_FUNCTION,
+           embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
+               query, user=user
+           ),
            k=form_data.k if form_data.k else request.app.state.config.TOP_K,
            reranking_function=request.app.state.rf,
            r=(
@@ -1315,12 +1339,16 @@ def query_doc_handler(
                if form_data.r
                else request.app.state.config.RELEVANCE_THRESHOLD
            ),
+           user=user,
        )
    else:
        return query_doc(
            collection_name=form_data.collection_name,
-           query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+           query_embedding=request.app.state.EMBEDDING_FUNCTION(
+               form_data.query, user=user
+           ),
            k=form_data.k if form_data.k else request.app.state.config.TOP_K,
+           user=user,
        )
    except Exception as e:
        log.exception(e)
@@ -1349,7 +1377,9 @@ def query_collection_handler(
        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            queries=[form_data.query],
-           embedding_function=request.app.state.EMBEDDING_FUNCTION,
+           embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
+               query, user=user
+           ),
            k=form_data.k if form_data.k else request.app.state.config.TOP_K,
            reranking_function=request.app.state.rf,
            r=(
@@ -1362,7 +1392,9 @@ def query_collection_handler(
        return query_collection(
            collection_names=form_data.collection_names,
            queries=[form_data.query],
-           embedding_function=request.app.state.EMBEDDING_FUNCTION,
+           embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
+               query, user=user
+           ),
            k=form_data.k if form_data.k else request.app.state.config.TOP_K,
        )
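These repeated `lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user)` wrappers exist because the retrieval helpers still invoke `embedding_function(query)` with a single argument; the lambda closes over the per-request user. The pattern distilled (helper name is hypothetical):

```python
def bind_user(embedding_fn, user):
    """Adapt a (query, user=None) embedding callable to the one-argument
    interface the retrieval helpers expect, capturing user in a closure."""
    return lambda query: embedding_fn(query, user=user)

# e.g. embedding_function = bind_user(request.app.state.EMBEDDING_FUNCTION, user)
```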
@@ -1510,6 +1542,7 @@ def process_files_batch(
        docs=all_docs,
        collection_name=collection_name,
        add=True,
+       user=user,
    )

    # Update all files with collection name
@@ -79,6 +79,7 @@ class ChatPermissions(BaseModel):
class FeaturesPermissions(BaseModel):
    web_search: bool = True
    image_generation: bool = True
+   code_interpreter: bool = True


class UserPermissions(BaseModel):
@@ -1,6 +1,8 @@
import time
import logging
import sys
+import os
+import base64

import asyncio
from aiocache import cached
@@ -10,6 +12,7 @@ import json
import html
import inspect
import re
+import ast

from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
@@ -55,6 +58,7 @@ from open_webui.utils.task import (
    tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
+   deep_update,
    get_message_list,
    add_or_update_system_message,
    add_or_update_user_message,
@@ -69,6 +73,7 @@ from open_webui.utils.plugin import load_function_module_by_id
from open_webui.tasks import create_task

from open_webui.config import (
+   CACHE_DIR,
    DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    DEFAULT_CODE_INTERPRETER_PROMPT,
)
@@ -180,7 +185,7 @@ async def chat_completion_filter_functions_handler(request, body, model, extra_params):


async def chat_completion_tools_handler(
-   request: Request, body: dict, user: UserModel, models, extra_params: dict
+   request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
    async def get_content_from_response(response) -> Optional[str]:
        content = None
@@ -215,35 +220,15 @@ async def chat_completion_tools_handler(
        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
    }

-   # If tool_ids field is present, call the functions
-   metadata = body.get("metadata", {})
-
-   tool_ids = metadata.get("tool_ids", None)
-   log.debug(f"{tool_ids=}")
-   if not tool_ids:
-       return body, {}
-
-   skip_files = False
-   sources = []
-
-   task_model_id = get_task_model_id(
-       body["model"],
-       request.app.state.config.TASK_MODEL,
-       request.app.state.config.TASK_MODEL_EXTERNAL,
-       models,
-   )
-   tools = get_tools(
-       request,
-       tool_ids,
-       user,
-       {
-           **extra_params,
-           "__model__": models[task_model_id],
-           "__messages__": body["messages"],
-           "__files__": metadata.get("files", []),
-       },
-   )
    log.info(f"{tools=}")

+   skip_files = False
+   sources = []
+
    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)
@@ -278,6 +263,8 @@ async def chat_completion_tools_handler(
        result = json.loads(content)

        async def tool_call_handler(tool_call):
+           nonlocal skip_files
+
            log.debug(f"{tool_call=}")

            tool_function_name = tool_call.get("name", None)
@@ -418,7 +405,7 @@ async def chat_web_search_handler(
            },
        }
    )
-   return
+   return form_data

searchQuery = queries[0]
@@ -641,7 +628,9 @@ async def chat_completion_files_handler(
        lambda: get_sources_from_files(
            files=files,
            queries=queries,
-           embedding_function=request.app.state.EMBEDDING_FUNCTION,
+           embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
+               query, user=user
+           ),
            k=request.app.state.config.TOP_K,
            reranking_function=request.app.state.rf,
            r=request.app.state.config.RELEVANCE_THRESHOLD,
@@ -693,6 +682,7 @@ def apply_params_to_form_data(form_data, model):


async def process_chat_payload(request, form_data, metadata, user, model):
+
    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

@@ -715,6 +705,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
    # Initialize events to store additional event to be sent to the client
    # Initialize contexts and citation
    models = request.app.state.MODELS
+   task_model_id = get_task_model_id(
+       form_data["model"],
+       request.app.state.config.TASK_MODEL,
+       request.app.state.config.TASK_MODEL_EXTERNAL,
+       models,
+   )

    events = []
    sources = []
@@ -799,13 +795,41 @@ async def process_chat_payload(request, form_data, metadata, user, model):
        }
        form_data["metadata"] = metadata

-       try:
-           form_data, flags = await chat_completion_tools_handler(
-               request, form_data, user, models, extra_params
-           )
-           sources.extend(flags.get("sources", []))
-       except Exception as e:
-           log.exception(e)
+       tool_ids = metadata.get("tool_ids", None)
+       log.debug(f"{tool_ids=}")
+
+       if tool_ids:
+           # If tool_ids field is present, then get the tools
+           tools = get_tools(
+               request,
+               tool_ids,
+               user,
+               {
+                   **extra_params,
+                   "__model__": models[task_model_id],
+                   "__messages__": form_data["messages"],
+                   "__files__": metadata.get("files", []),
+               },
+           )
+           log.info(f"{tools=}")
+
+           if metadata.get("function_calling") == "native":
+               # If the function calling is native, then call the tools function calling handler
+               metadata["tools"] = tools
+               form_data["tools"] = [
+                   {"type": "function", "function": tool.get("spec", {})}
+                   for tool in tools.values()
+               ]
+           else:
+               # If the function calling is not native, then call the tools function calling handler
+               try:
+                   form_data, flags = await chat_completion_tools_handler(
+                       request, form_data, user, models, tools
+                   )
+                   sources.extend(flags.get("sources", []))
+               except Exception as e:
+                   log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
@@ -821,11 +845,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):

        if "document" in source:
            for doc_idx, doc_context in enumerate(source["document"]):
-               metadata = source.get("metadata")
+               doc_metadata = source.get("metadata")
                doc_source_id = None

-               if metadata:
-                   doc_source_id = metadata[doc_idx].get("source", source_id)
+               if doc_metadata:
+                   doc_source_id = doc_metadata[doc_idx].get("source", source_id)

                if source_id:
                    context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
@@ -882,7 +906,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
        }
    )

-   return form_data, events
+   return form_data, metadata, events


async def process_chat_response(
@@ -1100,6 +1124,40 @@ async def process_chat_response(
        for block in content_blocks:
            if block["type"] == "text":
                content = f"{content}{block['content'].strip()}\n"
+           elif block["type"] == "tool_calls":
+               attributes = block.get("attributes", {})
+
+               block_content = block.get("content", [])
+               results = block.get("results", [])
+
+               if results:
+                   result_display_content = ""
+
+                   for result in results:
+                       tool_call_id = result.get("tool_call_id", "")
+                       tool_name = ""
+
+                       for tool_call in block_content:
+                           if tool_call.get("id", "") == tool_call_id:
+                               tool_name = tool_call.get("function", {}).get(
+                                   "name", ""
+                               )
+                               break
+
+                       result_display_content = f"{result_display_content}\n> {tool_name}: {result.get('content', '')}"
+
+                   if not raw:
+                       content = f'{content}\n<details type="tool_calls" done="true" content="{html.escape(json.dumps(block_content))}" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n{result_display_content}\n</details>\n'
+               else:
+                   tool_calls_display_content = ""
+
+                   for tool_call in block_content:
+                       tool_calls_display_content = f"{tool_calls_display_content}\n> Executing {tool_call.get('function', {}).get('name', '')}"
+
+                   if not raw:
+                       content = f'{content}\n<details type="tool_calls" done="false" content="{html.escape(json.dumps(block_content))}">\n<summary>Tool Executing...</summary>\n{tool_calls_display_content}\n</details>\n'
+
            elif block["type"] == "reasoning":
                reasoning_display_content = "\n".join(
                    (f"> {line}" if not line.startswith(">") else line)
@@ -1108,16 +1166,16 @@ async def process_chat_response(

                reasoning_duration = block.get("duration", None)

-               if reasoning_duration:
+               if reasoning_duration is not None:
                    if raw:
-                       content = f'{content}<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
+                       content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
                    else:
-                       content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
+                       content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                else:
                    if raw:
-                       content = f'{content}<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
+                       content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
                    else:
-                       content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
+                       content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'

            elif block["type"] == "code_interpreter":
                attributes = block.get("attributes", {})
@@ -1128,20 +1186,20 @@ async def process_chat_response(
                    output = html.escape(json.dumps(output))

                    if raw:
-                       content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
+                       content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
                    else:
-                       content = f'{content}<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
+                       content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
                else:
                    if raw:
-                       content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
+                       content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
                    else:
-                       content = f'{content}<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
+                       content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'

            else:
                block_content = str(block["content"]).strip()
                content = f"{content}{block['type']}: {block_content}\n"

-       return content
+       return content.strip()

    def tag_content_handler(content_type, tags, content, content_blocks):
        end_flag = False
@@ -1149,6 +1207,8 @@ async def process_chat_response(
        def extract_attributes(tag_content):
            """Extract attributes from a tag if they exist."""
            attributes = {}
+           if not tag_content:  # Ensure tag_content is not None
+               return attributes
            # Match attributes in the format: key="value" (ignores single quotes for simplicity)
            matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content)
            for key, value in matches:
@@ -1158,17 +1218,35 @@ async def process_chat_response(
        if content_blocks[-1]["type"] == "text":
            for tag in tags:
                # Match start tag e.g., <tag> or <tag attr="value">
-               start_tag_pattern = rf"<{tag}(.*?)>"
+               start_tag_pattern = rf"<{tag}(\s.*?)?>"
                match = re.search(start_tag_pattern, content)
                if match:
-                   # Extract attributes in the tag (if present)
-                   attributes = extract_attributes(match.group(1))
+                   attr_content = (
+                       match.group(1) if match.group(1) else ""
+                   )  # Ensure it's not None
+                   attributes = extract_attributes(
+                       attr_content
+                   )  # Extract attributes safely
+
+                   # Capture everything before and after the matched tag
+                   before_tag = content[
+                       : match.start()
+                   ]  # Content before opening tag
+                   after_tag = content[
+                       match.end() :
+                   ]  # Content after opening tag

                    # Remove the start tag from the currently handling text block
                    content_blocks[-1]["content"] = content_blocks[-1][
                        "content"
                    ].replace(match.group(0), "")
+
+                   if before_tag:
+                       content_blocks[-1]["content"] = before_tag
+
                    if not content_blocks[-1]["content"]:
                        content_blocks.pop()

                    # Append the new block
                    content_blocks.append(
                        {
@@ -1179,52 +1257,100 @@ async def process_chat_response(
                            "started_at": time.time(),
                        }
                    )

+                   if after_tag:
+                       content_blocks[-1]["content"] = after_tag
+
                    break
        elif content_blocks[-1]["type"] == content_type:
            tag = content_blocks[-1]["tag"]
            # Match end tag e.g., </tag>
            end_tag_pattern = rf"</{tag}>"

            # Check if the content has the end tag
            if re.search(end_tag_pattern, content):
-               end_flag = True
-
                block_content = content_blocks[-1]["content"]
-               # Strip start and end tags from the content
-               start_tag_pattern = rf"<{tag}(.*?)>"
-               block_content = re.sub(
-                   start_tag_pattern, "", block_content
-               ).strip()
-               block_content = re.sub(
-                   end_tag_pattern, "", block_content
-               ).strip()
+               end_tag_regex = re.compile(end_tag_pattern, re.DOTALL)
+               split_content = end_tag_regex.split(block_content, maxsplit=1)
+
+               # Content inside the tag
+               block_content = (
+                   split_content[0].strip() if split_content else ""
+               )
+
+               # Leftover content (everything after `</tag>`)
+               leftover_content = (
+                   split_content[1].strip() if len(split_content) > 1 else ""
+               )

                if block_content:
+                   end_flag = True
                    content_blocks[-1]["content"] = block_content
                    content_blocks[-1]["ended_at"] = time.time()
                    content_blocks[-1]["duration"] = int(
                        content_blocks[-1]["ended_at"]
                        - content_blocks[-1]["started_at"]
                    )

                    # Reset the content_blocks by appending a new text block
-                   content_blocks.append(
-                       {
-                           "type": "text",
-                           "content": "",
-                       }
-                   )
-                   # Clean processed content
-                   content = re.sub(
-                       rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
-                       "",
-                       content,
-                       flags=re.DOTALL,
-                   )
+                   if content_type != "code_interpreter":
+                       if leftover_content:
+                           content_blocks.append(
+                               {
+                                   "type": "text",
+                                   "content": leftover_content,
+                               }
+                           )
+                       else:
+                           content_blocks.append(
+                               {
+                                   "type": "text",
+                                   "content": "",
+                               }
+                           )
+
+               else:
+                   # Remove the block if content is empty
+                   content_blocks.pop()
+
+                   if leftover_content:
+                       content_blocks.append(
+                           {
+                               "type": "text",
+                               "content": leftover_content,
+                           }
+                       )
+                   else:
+                       content_blocks.append(
+                           {
+                               "type": "text",
+                               "content": "",
+                           }
+                       )
+
+               # Clean processed content
+               content = re.sub(
+                   rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
+                   "",
+                   content,
+                   flags=re.DOTALL,
+               )

        return content, content_blocks, end_flag

    message = Chats.get_message_by_id_and_message_id(
        metadata["chat_id"], metadata["message_id"]
    )

+   tool_calls = []
+
    content = message.get("content", "") if message else ""
    content_blocks = [
        {
@@ -1235,9 +1361,18 @@ async def process_chat_response(

    # We might want to disable this by default
    DETECT_REASONING = True
-   DETECT_CODE_INTERPRETER = True
+   DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
+       "code_interpreter", False
+   )

-   reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
+   reasoning_tags = [
+       "think",
+       "thinking",
+       "reason",
+       "reasoning",
+       "thought",
+       "Thought",
+   ]
    code_interpreter_tags = ["code_interpreter"]

    try:
@@ -1262,6 +1397,8 @@ async def process_chat_response(
            nonlocal content
            nonlocal content_blocks

+           response_tool_calls = []
+
            async for line in response.body_iterator:
                line = line.decode("utf-8") if isinstance(line, bytes) else line
                data = line
@@ -1294,10 +1431,54 @@ async def process_chat_response(
                    if not choices:
                        continue

-                   value = choices[0].get("delta", {}).get("content")
+                   delta = choices[0].get("delta", {})
+                   delta_tool_calls = delta.get("tool_calls", None)
+
+                   if delta_tool_calls:
+                       for delta_tool_call in delta_tool_calls:
+                           tool_call_index = delta_tool_call.get("index")
+
+                           if tool_call_index is not None:
+                               if (
+                                   len(response_tool_calls)
+                                   <= tool_call_index
+                               ):
+                                   response_tool_calls.append(
+                                       delta_tool_call
+                                   )
+                               else:
+                                   delta_name = delta_tool_call.get(
+                                       "function", {}
+                                   ).get("name")
+                                   delta_arguments = delta_tool_call.get(
+                                       "function", {}
+                                   ).get("arguments")
+
+                                   if delta_name:
+                                       response_tool_calls[
+                                           tool_call_index
+                                       ]["function"]["name"] += delta_name
+
+                                   if delta_arguments:
+                                       response_tool_calls[
+                                           tool_call_index
+                                       ]["function"][
+                                           "arguments"
+                                       ] += delta_arguments
+
+                   value = delta.get("content")

                    if value:
                        content = f"{content}{value}"

+                       if not content_blocks:
+                           content_blocks.append(
+                               {
+                                   "type": "text",
+                                   "content": "",
+                               }
+                           )
+
                        content_blocks[-1]["content"] = (
                            content_blocks[-1]["content"] + value
                        )
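What the delta-accumulation loop above does, restated compactly: OpenAI-style streams deliver each tool call in fragments keyed by `index`; the first fragment creates the entry, and later fragments append to `function.name` and `function.arguments`. A distilled sketch of the same logic:

```python
def accumulate_tool_calls(response_tool_calls: list, delta_tool_calls: list) -> None:
    """Merge one stream chunk's tool_call deltas into the running list (in place)."""
    for delta in delta_tool_calls:
        idx = delta.get("index")
        if idx is None:
            continue
        if len(response_tool_calls) <= idx:
            response_tool_calls.append(delta)  # first fragment for this call
            continue
        fn = delta.get("function", {})
        if fn.get("name"):
            response_tool_calls[idx]["function"]["name"] += fn["name"]
        if fn.get("arguments"):
            response_tool_calls[idx]["function"]["arguments"] += fn["arguments"]
```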
@@ -1357,14 +1538,46 @@ async def process_chat_response(
                    log.debug("Error: ", e)
                    continue

-           # Clean up the last text block
-           if content_blocks[-1]["type"] == "text":
-               content_blocks[-1]["content"] = content_blocks[-1][
-                   "content"
-               ].strip()
-
-               if not content_blocks[-1]["content"]:
-                   content_blocks.pop()
+           if content_blocks:
+               # Clean up the last text block
+               if content_blocks[-1]["type"] == "text":
+                   content_blocks[-1]["content"] = content_blocks[-1][
+                       "content"
+                   ].strip()
+
+                   if not content_blocks[-1]["content"]:
+                       content_blocks.pop()
+
+                   if not content_blocks:
+                       content_blocks.append(
+                           {
+                               "type": "text",
+                               "content": "",
+                           }
+                       )
+
+           if response_tool_calls:
+               tool_calls.append(response_tool_calls)
+
+           if response.background:
+               await response.background()
+
+       await stream_body_handler(response)
+
+       MAX_TOOL_CALL_RETRIES = 5
+       tool_call_retries = 0
+
+       while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
+           tool_call_retries += 1
+
+           response_tool_calls = tool_calls.pop(0)
+
+           content_blocks.append(
+               {
+                   "type": "tool_calls",
+                   "content": response_tool_calls,
+               }
+           )
+
+           await event_emitter(
+               {
@@ -1375,37 +1588,54 @@ async def process_chat_response(
                }
            )

-           if response.background:
-               await response.background()
+           tools = metadata.get("tools", {})

-           await stream_body_handler(response)
+           results = []
+           for tool_call in response_tool_calls:
+               print("\n\n" + str(tool_call) + "\n\n")
+               tool_call_id = tool_call.get("id", "")
+               tool_name = tool_call.get("function", {}).get("name", "")

-           MAX_RETRIES = 5
-           retries = 0
-
-           while (
-               content_blocks[-1]["type"] == "code_interpreter"
-               and retries < MAX_RETRIES
-           ):
-               retries += 1
-               log.debug(f"Attempt count: {retries}")
-
-               output = ""
-               try:
-                   if content_blocks[-1]["attributes"].get("type") == "code":
-                       output = await event_caller(
-                           {
-                               "type": "execute:python",
-                               "data": {
-                                   "id": str(uuid4()),
-                                   "code": content_blocks[-1]["content"],
-                               },
-                           }
-                       )
-               except Exception as e:
-                   output = str(e)
+               tool_function_params = {}
+               try:
+                   # json.loads cannot be used because some models do not produce valid JSON
+                   tool_function_params = ast.literal_eval(
+                       tool_call.get("function", {}).get("arguments", "{}")
+                   )
+               except Exception as e:
+                   log.debug(e)
+
+               tool_result = None
+
+               if tool_name in tools:
+                   tool = tools[tool_name]
+                   spec = tool.get("spec", {})
+
+                   try:
+                       required_params = spec.get("parameters", {}).get(
+                           "required", []
+                       )
+                       tool_function = tool["callable"]
+                       tool_function_params = {
+                           k: v
+                           for k, v in tool_function_params.items()
+                           if k in required_params
+                       }
+                       tool_result = await tool_function(
+                           **tool_function_params
+                       )
+                   except Exception as e:
+                       tool_result = str(e)
+
+               results.append(
+                   {
+                       "tool_call_id": tool_call_id,
+                       "content": tool_result,
+                   }
+               )

-           content_blocks[-1]["output"] = output
+           content_blocks[-1]["results"] = results
+
            content_blocks.append(
                {
                    "type": "text",
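The diff parses tool arguments with `ast.literal_eval` rather than `json.loads` because some models emit Python-literal dicts (single quotes, `True`, `None`) instead of strict JSON. A tolerant variant of the same idea (a sketch, not the diff's exact code — the diff uses only the literal_eval path):

```python
import ast
import json


def parse_tool_arguments(raw: str) -> dict:
    """Accept strict JSON first, then fall back to Python literals."""
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        pass
    try:
        value = ast.literal_eval(raw)
        return value if isinstance(value, dict) else {}
    except (ValueError, SyntaxError):
        return {}
```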
@@ -1435,7 +1665,16 @@ async def process_chat_response(
                            "content": serialize_content_blocks(
                                content_blocks, raw=True
                            ),
+                           "tool_calls": response_tool_calls,
                        },
+                       *[
+                           {
+                               "role": "tool",
+                               "tool_call_id": result["tool_call_id"],
+                               "content": result["content"],
+                           }
+                           for result in results
+                       ],
                    ],
                },
                user,
@@ -1449,6 +1688,110 @@ async def process_chat_response(
                    log.debug(e)
                    break

+           if DETECT_CODE_INTERPRETER:
+               MAX_RETRIES = 5
+               retries = 0
+
+               while (
+                   content_blocks[-1]["type"] == "code_interpreter"
+                   and retries < MAX_RETRIES
+               ):
+                   retries += 1
+                   log.debug(f"Attempt count: {retries}")
+
+                   output = ""
+                   try:
+                       if content_blocks[-1]["attributes"].get("type") == "code":
+                           output = await event_caller(
+                               {
+                                   "type": "execute:python",
+                                   "data": {
+                                       "id": str(uuid4()),
+                                       "code": content_blocks[-1]["content"],
+                                   },
+                               }
+                           )
+
+                           if isinstance(output, dict):
+                               stdout = output.get("stdout", "")
+
+                               if stdout:
+                                   stdoutLines = stdout.split("\n")
+                                   for idx, line in enumerate(stdoutLines):
+                                       if "data:image/png;base64" in line:
+                                           id = str(uuid4())
+
+                                           # ensure the path exists
+                                           os.makedirs(
+                                               os.path.join(CACHE_DIR, "images"),
+                                               exist_ok=True,
+                                           )
+
+                                           image_path = os.path.join(
+                                               CACHE_DIR,
+                                               f"images/{id}.png",
+                                           )
+
+                                           with open(image_path, "wb") as f:
+                                               f.write(
+                                                   base64.b64decode(
+                                                       line.split(",")[1]
+                                                   )
+                                               )
+
+                                           stdoutLines[idx] = (
+                                               f"![Output Image {idx}](/cache/images/{id}.png)"
+                                           )
+
+                                   output["stdout"] = "\n".join(stdoutLines)
+                   except Exception as e:
+                       output = str(e)
+
+                   content_blocks[-1]["output"] = output
+
+                   content_blocks.append(
+                       {
+                           "type": "text",
+                           "content": "",
+                       }
+                   )
+
+                   await event_emitter(
+                       {
+                           "type": "chat:completion",
+                           "data": {
+                               "content": serialize_content_blocks(content_blocks),
+                           },
+                       }
+                   )
+
+                   try:
+                       res = await generate_chat_completion(
+                           request,
+                           {
+                               "model": model_id,
+                               "stream": True,
+                               "messages": [
+                                   *form_data["messages"],
+                                   {
+                                       "role": "assistant",
+                                       "content": serialize_content_blocks(
+                                           content_blocks, raw=True
+                                       ),
+                                   },
+                               ],
+                           },
+                           user,
+                       )
+
+                       if isinstance(res, StreamingResponse):
+                           await stream_body_handler(res)
+                       else:
+                           break
+                   except Exception as e:
+                       log.debug(e)
+                       break
+
        title = Chats.get_chat_title_by_id(metadata["chat_id"])
        data = {
            "done": True,
@@ -7,6 +7,18 @@ from pathlib import Path
from typing import Callable, Optional


+import collections.abc
+
+
+def deep_update(d, u):
+    for k, v in u.items():
+        if isinstance(v, collections.abc.Mapping):
+            d[k] = deep_update(d.get(k, {}), v)
+        else:
+            d[k] = v
+    return d
+
+
def get_message_list(messages, message_id):
    """
    Reconstructs a list of messages in order up to the specified message_id.
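Behavior of the new `deep_update` helper, by example — nested mappings merge recursively while scalars overwrite, and the first argument is mutated and returned (values here are made up):

```python
from open_webui.utils.misc import deep_update

base = {"a": {"x": 1, "y": 2}, "b": 3}
patch = {"a": {"y": 20, "z": 30}, "b": 4}

deep_update(base, patch)
print(base)  # {'a': {'x': 1, 'y': 20, 'z': 30}, 'b': 4}
```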
@@ -179,15 +191,25 @@ def openai_chat_message_template(model: str):


def openai_chat_chunk_message_template(
-   model: str, message: Optional[str] = None, usage: Optional[dict] = None
+   model: str,
+   content: Optional[str] = None,
+   tool_calls: Optional[list[dict]] = None,
+   usage: Optional[dict] = None,
) -> dict:
    template = openai_chat_message_template(model)
    template["object"] = "chat.completion.chunk"
-   if message:
-       template["choices"][0]["delta"] = {"content": message}
-   else:
-       template["choices"][0]["finish_reason"] = "stop"
-       template["choices"][0]["delta"] = {}

+   template["choices"][0]["index"] = 0
+   template["choices"][0]["delta"] = {}
+
+   if content:
+       template["choices"][0]["delta"]["content"] = content
+
+   if tool_calls:
+       template["choices"][0]["delta"]["tool_calls"] = tool_calls
+
+   if not content and not tool_calls:
+       template["choices"][0]["finish_reason"] = "stop"
+
    if usage:
        template["usage"] = usage
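Resulting chunk shapes under the new signature, assuming the reconstruction above is faithful (a sketch, not a test from the repo):

```python
from open_webui.utils.misc import openai_chat_chunk_message_template

# Content chunk: delta carries text.
chunk = openai_chat_chunk_message_template("llama3", content="Hel")
assert chunk["choices"][0]["delta"] == {"content": "Hel"}

# Tool-call chunk: delta carries the accumulated tool_calls list.
tc = [{"index": 0, "id": "call_1", "type": "function",
       "function": {"name": "get_weather", "arguments": "{}"}}]
chunk = openai_chat_chunk_message_template("llama3", tool_calls=tc)
assert chunk["choices"][0]["delta"]["tool_calls"] == tc

# Neither content nor tool_calls: terminal chunk with finish_reason "stop".
chunk = openai_chat_chunk_message_template("llama3")
assert chunk["choices"][0]["finish_reason"] == "stop"
```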
@@ -1,4 +1,4 @@
-from open_webui.utils.task import prompt_variables_template
+from open_webui.utils.task import prompt_template, prompt_variables_template
from open_webui.utils.misc import (
    add_or_update_system_message,
)
@@ -8,16 +8,28 @@ from typing import Callable, Optional

# inplace function: form_data is modified
def apply_model_system_prompt_to_body(
-   params: dict, form_data: dict, metadata: Optional[dict] = None
+   params: dict, form_data: dict, metadata: Optional[dict] = None, user=None
) -> dict:
    system = params.get("system", None)
    if not system:
        return form_data

+   # Legacy (API Usage)
+   if user:
+       template_params = {
+           "user_name": user.name,
+           "user_location": user.info.get("location") if user.info else None,
+       }
+   else:
+       template_params = {}
+
+   system = prompt_template(system, **template_params)
+
+   # Metadata (WebUI Usage)
    if metadata:
+       print("apply_model_system_prompt_to_body: metadata", metadata)
        variables = metadata.get("variables", {})
-       system = prompt_variables_template(system, variables)
+       if variables:
+           system = prompt_variables_template(system, variables)

    form_data["messages"] = add_or_update_system_message(
        system, form_data.get("messages", [])
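Taken together, the new flow runs two template passes: `prompt_template` fills user-derived fields when a `user` is passed, then `prompt_variables_template` applies chat variables from `metadata`. A usage sketch — the placeholder and variable names are assumptions about `prompt_template`, not shown in this diff:

```python
params = {"system": "You are assisting {{USER_NAME}}."}
form_data = {"messages": [{"role": "user", "content": "hi"}]}

form_data = apply_model_system_prompt_to_body(
    params,
    form_data,
    metadata={"variables": {"{{LOCATION}}": "Berlin"}},  # hypothetical chat variable
    user=current_user,  # supplies user_name / user_location to prompt_template
)
# form_data["messages"][0] is now the rendered system message.
```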
@@ -154,6 +166,9 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
    )
    ollama_payload["stream"] = openai_payload.get("stream", False)

+   if "tools" in openai_payload:
+       ollama_payload["tools"] = openai_payload["tools"]
+
    if "format" in openai_payload:
        ollama_payload["format"] = openai_payload["format"]

@@ -1,4 +1,5 @@
import json
+from uuid import uuid4
from open_webui.utils.misc import (
    openai_chat_chunk_message_template,
    openai_chat_completion_message_template,
@@ -60,6 +61,23 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):

    model = data.get("model", "ollama")
    message_content = data.get("message", {}).get("content", "")
+   tool_calls = data.get("message", {}).get("tool_calls", None)
+   openai_tool_calls = None
+
+   if tool_calls:
+       openai_tool_calls = []
+       for tool_call in tool_calls:
+           openai_tool_call = {
+               "index": tool_call.get("index", 0),
+               "id": tool_call.get("id", f"call_{str(uuid4())}"),
+               "type": "function",
+               "function": {
+                   "name": tool_call.get("function", {}).get("name", ""),
+                   "arguments": f"{tool_call.get('function', {}).get('arguments', {})}",
+               },
+           }
+           openai_tool_calls.append(openai_tool_call)

    done = data.get("done", False)

    usage = None
@@ -105,7 +123,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):
    }

    data = openai_chat_chunk_message_template(
-       model, message_content if not done else None, usage
+       model, message_content if not done else None, openai_tool_calls, usage
    )

    line = f"data: {json.dumps(data)}\n\n"
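Shape of that conversion, by example: Ollama emits structured `arguments`, while OpenAI-style chunks expect a string, hence the f-string cast (values and the generated id are illustrative placeholders):

```python
# Ollama tool call (arguments arrive as a dict):
ollama_tool_call = {
    "function": {"name": "get_weather", "arguments": {"city": "Berlin"}}
}

# Converted OpenAI-style entry (id generated via f"call_{uuid4()}" when absent):
openai_tool_call = {
    "index": 0,
    "id": "call_<uuid4>",  # placeholder
    "type": "function",
    "function": {"name": "get_weather", "arguments": "{'city': 'Berlin'}"},
}
```

Note the stringified arguments come out as a Python literal, which is exactly the form the middleware's `ast.literal_eval` parsing accepts.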
@@ -61,6 +61,12 @@ def get_tools(
    )

    for spec in tools.specs:
+       # TODO: Fix hack for OpenAI API
+       # Sometimes breaks OpenAI but others don't. Leaving the comment
+       for val in spec.get("parameters", {}).get("properties", {}).values():
+           if val["type"] == "str":
+               val["type"] = "string"
+
        # Remove internal parameters
        spec["parameters"]["properties"] = {
            key: val
@@ -73,6 +79,13 @@ def get_tools(
        # convert to function that takes only model params and inserts custom params
        original_func = getattr(module, function_name)
        callable = apply_extra_params_to_tool_function(original_func, extra_params)

+       if callable.__doc__ and callable.__doc__.strip() != "":
+           s = re.split(":(param|return)", callable.__doc__, 1)
+           spec["description"] = s[0]
+       else:
+           spec["description"] = function_name
+
        # TODO: This needs to be a pydantic model
        tool_dict = {
            "toolkit_id": tool_id,