Timothy Jaeryang Baek
2024-12-12 22:28:42 -08:00
parent 8c38708827
commit 1197c640c4
6 changed files with 664 additions and 825 deletions


@@ -41,8 +41,6 @@ from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.routers import (
audio,
@@ -74,12 +72,6 @@ from open_webui.routers.retrieval import (
get_ef,
get_rf,
)
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
)
from open_webui.retrieval.utils import get_sources_from_files
from open_webui.internal.db import Session
@@ -87,8 +79,6 @@ from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users
from open_webui.constants import TASKS
from open_webui.config import (
# Ollama
ENABLE_OLLAMA_API,
@@ -274,43 +264,22 @@ from open_webui.env import (
)
from open_webui.utils.models import get_all_models, get_all_base_models
from open_webui.utils.models import (
get_all_models,
get_all_base_models,
check_model_access,
)
from open_webui.utils.chat import (
generate_chat_completion as chat_completion_handler,
chat_completed as chat_completed_handler,
chat_action as chat_action_handler,
)
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.task import (
get_task_model_id,
rag_template,
tools_function_calling_generation_template,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access
from open_webui.utils.auth import (
decode_token,
get_admin_user,
get_current_user,
get_http_authorization_cred,
get_verified_user,
)
from open_webui.utils.oauth import oauth_manager
@@ -665,634 +634,6 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
app.state.MODELS = {}
##################################
#
# ChatCompletion Middleware
#
##################################
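# Applies each enabled "filter" function's inlet() hook to the request body in
# ascending valve priority, passing only the keyword arguments the hook's
# signature declares (e.g. __user__, __model__, __id__).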
async def chat_completion_filter_functions_handler(body, model, extra_params):
skip_files = None
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [
function.id for function in Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
# Avoid shadowing the built-in filter()
filter_function = Functions.get_function_by_id(filter_id)
if not filter_function:
continue
if filter_id in app.state.FUNCTIONS:
function_module = app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "inlet"):
continue
try:
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": body} | {
k: v
for k, v in {
**extra_params,
"__model__": model,
"__id__": filter_id,
}.items()
if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
try:
params["__user__"]["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, params["__user__"]["id"]
)
)
except Exception as e:
log.exception(e)
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
return body, {}
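# Resolves the tools named in metadata["tool_ids"], prompts the task model to
# pick one as a JSON {"name": ..., "parameters": ...} object, invokes the
# matching callable, and records string output as citation sources.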
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, extra_params: dict
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
return content
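# Builds a one-shot, non-streaming payload asking the task model to choose a
# tool, quoting up to the four most recent messages as history.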
def get_tools_function_calling_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages)
history = "\n".join(
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
)
prompt = f"History:\n{history}\nQuery: {user_message}"
return {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
tool_ids = metadata.get("tool_ids", None)
log.debug(f"{tool_ids=}")
if not tool_ids:
return body, {}
skip_files = False
sources = []
task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
tools = get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": body["messages"],
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
specs = [tool["spec"] for tool in tools.values()]
tools_specs = json.dumps(specs)
if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
else:
template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
tools_function_calling_prompt = tools_function_calling_generation_template(
template, tools_specs
)
log.info(f"{tools_function_calling_prompt=}")
payload = get_tools_function_calling_payload(
body["messages"], task_model_id, tools_function_calling_prompt
)
try:
payload = process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
response = await generate_chat_completions(form_data=payload, user=user)
log.debug(f"{response=}")
content = await get_content_from_response(response)
log.debug(f"{content=}")
if not content:
return body, {}
try:
content = content[content.find("{") : content.rfind("}") + 1]
if not content:
raise Exception("No JSON object found in the response")
result = json.loads(content)
tool_function_name = result.get("name", None)
if tool_function_name not in tools:
return body, {}
tool_function_params = result.get("parameters", {})
try:
required_params = (
tools[tool_function_name]
.get("spec", {})
.get("parameters", {})
.get("required", [])
)
tool_function = tools[tool_function_name]["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
}
tool_output = await tool_function(**tool_function_params)
except Exception as e:
tool_output = str(e)
if isinstance(tool_output, str):
if tools[tool_function_name]["citation"]:
sources.append(
{
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [
{
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
}
],
}
)
else:
sources.append(
{
"source": {},
"document": [tool_output],
"metadata": [
{
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
}
],
}
)
if tools[tool_function_name]["file_handler"]:
skip_files = True
except Exception as e:
log.exception(f"Error: {e}")
content = None
except Exception as e:
log.exception(f"Error: {e}")
content = None
log.debug(f"tool_contexts: {sources}")
if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
return body, {"sources": sources}
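# Generates retrieval queries for any files attached to the request (falling
# back to the last user message) and collects matching chunks as RAG sources.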
async def chat_completion_files_handler(
request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
sources = []
if files := body.get("metadata", {}).get("files", None):
try:
queries_response = await generate_queries(
{
"model": body["model"],
"messages": body["messages"],
"type": "retrieval",
},
user,
)
queries_response = queries_response["choices"][0]["message"]["content"]
try:
bracket_start = queries_response.find("{")
bracket_end = queries_response.rfind("}") + 1
# rfind returns -1 when "}" is missing, so the +1 above yields 0, never -1
if bracket_start == -1 or bracket_end == 0:
raise Exception("No JSON object found in the response")
queries_response = queries_response[bracket_start:bracket_end]
queries_response = json.loads(queries_response)
except Exception as e:
queries_response = {"queries": [queries_response]}
queries = queries_response.get("queries", [])
except Exception as e:
queries = []
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
sources = get_sources_from_files(
files=files,
queries=queries,
embedding_function=request.app.state.EMBEDDING_FUNCTION,
k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts:sources: {sources}")
return body, {"sources": sources}
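# Intercepts POSTs to the chat endpoints, runs the filter/tools/files handlers
# above, injects retrieved context into the prompt, and prepends any collected
# source events to the streamed response (functionality that appears to move
# to open_webui.utils.middleware in this commit).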
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if not (
request.method == "POST"
and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
)
):
return await call_next(request)
log.debug(f"request.url.path: {request.url.path}")
await get_all_models(request)
models = app.state.MODELS
async def get_body_and_model_and_user(request, models):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
try:
body, model, user = await get_body_and_model_and_user(request, models)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
model_info = Models.get_model_by_id(model["id"])
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
if model.get("arena"):
if not has_access(
user.id,
type="read",
access_control=model.get("info", {})
.get("meta", {})
.get("access_control", {}),
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
else:
if not model_info:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": "Model not found"},
)
elif not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": "User does not have access to the model"},
)
metadata = {
"chat_id": body.pop("chat_id", None),
"message_id": body.pop("id", None),
"session_id": body.pop("session_id", None),
"tool_ids": body.get("tool_ids", None),
"files": body.get("files", None),
}
body["metadata"] = metadata
extra_params = {
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
}
# Initialize data_items to store additional data to be sent to the client
# Initialize contexts and citation sources
data_items = []
sources = []
try:
body, flags = await chat_completion_filter_functions_handler(
body, model, extra_params
)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
tool_ids = body.pop("tool_ids", None)
files = body.pop("files", None)
metadata = {
**metadata,
"tool_ids": tool_ids,
"files": files,
}
body["metadata"] = metadata
try:
body, flags = await chat_completion_tools_handler(
request, body, user, models, extra_params
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
try:
body, flags = await chat_completion_files_handler(request, body, user)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
for source_idx, source in enumerate(sources):
source_id = source.get("source", {}).get("name", "")
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
metadata = source.get("metadata")
doc_source_id = None
if metadata:
doc_source_id = metadata[doc_idx].get("source", source_id)
if source_id:
context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
else:
# If there is no source_id, then do not include the source_id tag
context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
context_string = context_string.strip()
prompt = get_last_user_message(body["messages"])
if prompt is None:
raise Exception("No user message found")
if (
app.state.config.RELEVANCE_THRESHOLD == 0
and context_string.strip() == ""
):
log.debug(
"With a RAG relevance threshold of 0, the context should never be empty"
)
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
body["messages"] = prepend_to_first_user_message_content(
rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
body["messages"],
)
else:
body["messages"] = add_or_update_system_message(
rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
body["messages"],
)
# If there are citations, add them to the data_items
sources = [
source for source in sources if source.get("source", {}).get("name", "")
]
if len(sources) > 0:
data_items.append({"sources": sources})
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
]
response = await call_next(request)
if not isinstance(response, StreamingResponse):
return response
content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
if not is_openai and not is_ollama:
return response
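# Replays queued data_items (e.g. sources) in the stream's native framing
# (SSE "data:" lines for OpenAI, NDJSON for Ollama) before relaying the
# model's own chunks.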
def wrap_item(item):
return f"data: {item}\n\n" if is_openai else f"{item}\n"
async def stream_wrapper(original_generator, data_items):
for item in data_items:
yield wrap_item(json.dumps(item))
async for data in original_generator:
yield data
return StreamingResponse(
stream_wrapper(response.body_iterator, data_items),
headers=dict(response.headers),
)
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(ChatCompletionMiddleware)
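# Runs the configured pipeline inlet filters over the raw request body for the
# same chat endpoints, rewriting the body before it reaches the router.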
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if not (
request.method == "POST"
and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
)
):
return await call_next(request)
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
try:
user = get_current_user(
request,
get_http_authorization_cred(request.headers["Authorization"]),
)
except KeyError as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Not authenticated"},
)
except HTTPException as e:
return JSONResponse(
status_code=e.status_code,
content={"detail": e.detail},
)
await get_all_models(request)
models = app.state.MODELS
try:
data = process_pipeline_inlet_filter(request, data, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
]
response = await call_next(request)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(PipelineMiddleware)
class RedirectMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
@@ -1471,8 +812,32 @@ async def chat_completion(
user=Depends(get_verified_user),
bypass_filter: bool = False,
):
try:
return await chat_completion_handler(request, form_data, user, bypass_filter)
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
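# process_chat_payload (imported from open_webui.utils.middleware) now handles
# the filter/tools/files preprocessing that the deleted ChatCompletionMiddleware
# above used to perform.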
form_data, events = await process_chat_payload(request, form_data, user, model)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
try:
response = await chat_completion_handler(
request, form_data, user, bypass_filter
)
return await process_chat_response(response, events)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -1480,6 +845,7 @@ async def chat_completion(
)
# Alias for chat_completion (Legacy)
generate_chat_completions = chat_completion
generate_chat_completion = chat_completion
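# Both legacy names now resolve to the same endpoint, so existing call sites
# keep working, e.g. (as in the deleted tools handler above):
# response = await generate_chat_completions(form_data=payload, user=user)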