refac
@@ -41,8 +41,6 @@ from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
    app as socket_app,
    periodic_usage_pool_cleanup,
    get_event_call,
    get_event_emitter,
)
from open_webui.routers import (
    audio,
@@ -74,12 +72,6 @@ from open_webui.routers.retrieval import (
    get_ef,
    get_rf,
)
from open_webui.routers.pipelines import (
    process_pipeline_inlet_filter,
)

from open_webui.retrieval.utils import get_sources_from_files


from open_webui.internal.db import Session

@@ -87,8 +79,6 @@ from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users


from open_webui.constants import TASKS
from open_webui.config import (
    # Ollama
    ENABLE_OLLAMA_API,
@@ -274,43 +264,22 @@ from open_webui.env import (
)


from open_webui.utils.models import get_all_models, get_all_base_models
from open_webui.utils.models import (
    get_all_models,
    get_all_base_models,
    check_model_access,
)
from open_webui.utils.chat import (
    generate_chat_completion as chat_completion_handler,
    chat_completed as chat_completed_handler,
    chat_action as chat_action_handler,
)


from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.misc import (
    add_or_update_system_message,
    get_last_user_message,
    prepend_to_first_user_message_content,
    openai_chat_chunk_message_template,
    openai_chat_completion_message_template,
)


from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
    convert_response_ollama_to_openai,
    convert_streaming_response_ollama_to_openai,
)

from open_webui.utils.task import (
    get_task_model_id,
    rag_template,
    tools_function_calling_generation_template,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access

from open_webui.utils.auth import (
    decode_token,
    get_admin_user,
    get_current_user,
    get_http_authorization_cred,
    get_verified_user,
)
from open_webui.utils.oauth import oauth_manager
@@ -665,634 +634,6 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (

app.state.MODELS = {}

##################################
#
# ChatCompletion Middleware
#
##################################


async def chat_completion_filter_functions_handler(body, model, extra_params):
    skip_files = None

    def get_filter_function_ids(model):
        def get_priority(function_id):
            function = Functions.get_function_by_id(function_id)
            if function is not None and hasattr(function, "valves"):
                # TODO: Fix FunctionModel
                return (function.valves if function.valves else {}).get("priority", 0)
            return 0

        filter_ids = [
            function.id for function in Functions.get_global_filter_functions()
        ]
        if "info" in model and "meta" in model["info"]:
            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
        filter_ids = list(set(filter_ids))

        enabled_filter_ids = [
            function.id
            for function in Functions.get_functions_by_type("filter", active_only=True)
        ]

        filter_ids = [
            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
        ]

        filter_ids.sort(key=get_priority)
        return filter_ids

    filter_ids = get_filter_function_ids(model)
    for filter_id in filter_ids:
        filter = Functions.get_function_by_id(filter_id)
        if not filter:
            continue

        if filter_id in app.state.FUNCTIONS:
            function_module = app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            app.state.FUNCTIONS[filter_id] = function_module

        # Check if the function has a file_handler variable
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "inlet"):
            continue

        try:
            inlet = function_module.inlet

            # Get the signature of the function
            sig = inspect.signature(inlet)
            params = {"body": body} | {
                k: v
                for k, v in {
                    **extra_params,
                    "__model__": model,
                    "__id__": filter_id,
                }.items()
                if k in sig.parameters
            }

            if "__user__" in params and hasattr(function_module, "UserValves"):
                try:
                    params["__user__"]["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            filter_id, params["__user__"]["id"]
                        )
                    )
                except Exception as e:
                    print(e)

            if inspect.iscoroutinefunction(inlet):
                body = await inlet(**params)
            else:
                body = inlet(**params)

        except Exception as e:
            print(f"Error: {e}")
            raise e

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {}


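For reference, the shape of the filter object this removed handler loaded: everything it needs is discovered via hasattr (inlet, Valves/valves, file_handler, UserValves), so any object exposing those attributes qualifies. A minimal sketch, assuming the conventional Open WebUI Filter-class layout; the field and hook names below are only the ones the handler probes for, everything else is illustrative:

from pydantic import BaseModel


class Filter:
    class Valves(BaseModel):
        # get_priority() above reads this key to order the filter chain
        priority: int = 0

    def __init__(self):
        self.valves = self.Valves()
        # If truthy, the handler sets skip_files and drops body["metadata"]["files"]
        self.file_handler = False

    def inlet(self, body: dict, __user__: dict = None) -> dict:
        # The handler inspects this signature and passes only the parameters
        # the hook declares (e.g. __user__, __model__, __id__); mutate and
        # return the request body.
        body["messages"] = body.get("messages", [])
        return body
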
async def chat_completion_tools_handler(
    request: Request, body: dict, user: UserModel, models, extra_params: dict
) -> tuple[dict, dict]:
    async def get_content_from_response(response) -> Optional[str]:
        content = None
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        user_message = get_last_user_message(messages)
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )

        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    # If tool_ids field is present, call the functions
    metadata = body.get("metadata", {})

    tool_ids = metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")
    if not tool_ids:
        return body, {}

    skip_files = False
    sources = []

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    tools = get_tools(
        request,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models[task_model_id],
            "__messages__": body["messages"],
            "__files__": metadata.get("files", []),
        },
    )
    log.info(f"{tools=}")

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        payload = process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        response = await generate_chat_completions(form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            tool_function_name = result.get("name", None)
            if tool_function_name not in tools:
                return body, {}

            tool_function_params = result.get("parameters", {})

            try:
                required_params = (
                    tools[tool_function_name]
                    .get("spec", {})
                    .get("parameters", {})
                    .get("required", [])
                )
                tool_function = tools[tool_function_name]["callable"]
                tool_function_params = {
                    k: v
                    for k, v in tool_function_params.items()
                    if k in required_params
                }
                tool_output = await tool_function(**tool_function_params)

            except Exception as e:
                tool_output = str(e)

            if isinstance(tool_output, str):
                if tools[tool_function_name]["citation"]:
                    sources.append(
                        {
                            "source": {
                                "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                            },
                            "document": [tool_output],
                            "metadata": [
                                {
                                    "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                }
                            ],
                        }
                    )
                else:
                    sources.append(
                        {
                            "source": {},
                            "document": [tool_output],
                            "metadata": [
                                {
                                    "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                }
                            ],
                        }
                    )

                if tools[tool_function_name]["file_handler"]:
                    skip_files = True

        except Exception as e:
            log.exception(f"Error: {e}")
            content = None
    except Exception as e:
        log.exception(f"Error: {e}")
        content = None

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}


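The parse-and-dispatch step above is the whole function-calling protocol: brace-slice one JSON object out of the task model's reply, look the name up in the tools registry, keep only the spec-required parameters, and await the callable. A self-contained sketch of that step; the registry entry mirrors the shape the handler reads from get_tools(), and the tool itself is a stand-in:

import asyncio
import json


async def dispatch_tool_call(content: str, tools: dict):
    # Tolerant extraction: slice from the first "{" to the last "}".
    start, end = content.find("{"), content.rfind("}") + 1
    if start == -1 or end == 0:
        return None
    result = json.loads(content[start:end])
    name = result.get("name")
    if name not in tools:
        return None
    required = tools[name].get("spec", {}).get("parameters", {}).get("required", [])
    # Drop any hallucinated parameters the spec does not require.
    params = {k: v for k, v in result.get("parameters", {}).items() if k in required}
    return await tools[name]["callable"](**params)


async def get_weather(city: str) -> str:  # stand-in tool
    return f"Sunny in {city}"


tools = {
    "get_weather": {
        "spec": {"parameters": {"required": ["city"]}},
        "callable": get_weather,
    }
}

reply = 'Calling a tool: {"name": "get_weather", "parameters": {"city": "Oslo"}}'
print(asyncio.run(dispatch_tool_call(reply, tools)))  # -> Sunny in Oslo
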
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        try:
            queries_response = await generate_queries(
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            queries_response = queries_response["choices"][0]["message"]["content"]

            try:
                bracket_start = queries_response.find("{")
                bracket_end = queries_response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == -1:
                    raise Exception("No JSON object found in the response")

                queries_response = queries_response[bracket_start:bracket_end]
                queries_response = json.loads(queries_response)
            except Exception as e:
                queries_response = {"queries": [queries_response]}

            queries = queries_response.get("queries", [])
        except Exception as e:
            queries = []

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        sources = get_sources_from_files(
            files=files,
            queries=queries,
            embedding_function=request.app.state.EMBEDDING_FUNCTION,
            k=request.app.state.config.TOP_K,
            reranking_function=request.app.state.rf,
            r=request.app.state.config.RELEVANCE_THRESHOLD,
            hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
        )

    log.debug(f"rag_contexts:sources: {sources}")
    return body, {"sources": sources}


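Note the double fallback above: if the generated reply is not valid JSON, the whole reply becomes the single query, and if even that yields nothing, the last user message is used. Also note that since rfind returns -1 on a miss, bracket_end is 0 rather than -1 after the +1, so the -1 comparison above never fires; the sketch below checks the 0 sentinel explicitly. parse_queries is a hypothetical helper, not part of the codebase:

import json


def parse_queries(reply: str, last_user_message: str) -> list[str]:
    # JSON object -> raw reply -> last user message, degrading gracefully.
    try:
        start, end = reply.find("{"), reply.rfind("}") + 1
        if start == -1 or end == 0:
            raise ValueError("No JSON object found in the response")
        queries = json.loads(reply[start:end]).get("queries", [])
    except Exception:
        queries = [reply]
    return queries if queries else [last_user_message]


print(parse_queries('{"queries": ["a", "b"]}', "fallback"))  # ['a', 'b']
print(parse_queries("not json at all", "fallback"))          # ['not json at all']
print(parse_queries('{"queries": []}', "fallback"))          # ['fallback']
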
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if not (
            request.method == "POST"
            and any(
                endpoint in request.url.path
                for endpoint in ["/ollama/api/chat", "/chat/completions"]
            )
        ):
            return await call_next(request)
        log.debug(f"request.url.path: {request.url.path}")

        await get_all_models(request)
        models = app.state.MODELS

        async def get_body_and_model_and_user(request, models):
            # Read the original request body
            body = await request.body()
            body_str = body.decode("utf-8")
            body = json.loads(body_str) if body_str else {}

            model_id = body["model"]
            if model_id not in models:
                raise Exception("Model not found")
            model = models[model_id]

            user = get_current_user(
                request,
                get_http_authorization_cred(request.headers.get("Authorization")),
            )

            return body, model, user

        try:
            body, model, user = await get_body_and_model_and_user(request, models)
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        model_info = Models.get_model_by_id(model["id"])
        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
            if model.get("arena"):
                if not has_access(
                    user.id,
                    type="read",
                    access_control=model.get("info", {})
                    .get("meta", {})
                    .get("access_control", {}),
                ):
                    raise HTTPException(
                        status_code=403,
                        detail="Model not found",
                    )
            else:
                if not model_info:
                    return JSONResponse(
                        status_code=status.HTTP_404_NOT_FOUND,
                        content={"detail": "Model not found"},
                    )
                elif not (
                    user.id == model_info.user_id
                    or has_access(
                        user.id, type="read", access_control=model_info.access_control
                    )
                ):
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={"detail": "User does not have access to the model"},
                    )

        metadata = {
            "chat_id": body.pop("chat_id", None),
            "message_id": body.pop("id", None),
            "session_id": body.pop("session_id", None),
            "tool_ids": body.get("tool_ids", None),
            "files": body.get("files", None),
        }
        body["metadata"] = metadata

        extra_params = {
            "__event_emitter__": get_event_emitter(metadata),
            "__event_call__": get_event_call(metadata),
            "__user__": {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            },
            "__metadata__": metadata,
        }

        # Initialize data_items to store additional data to be sent to the client
        # Initialize contexts and citation
        data_items = []
        sources = []

        try:
            body, flags = await chat_completion_filter_functions_handler(
                body, model, extra_params
            )
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        tool_ids = body.pop("tool_ids", None)
        files = body.pop("files", None)

        metadata = {
            **metadata,
            "tool_ids": tool_ids,
            "files": files,
        }
        body["metadata"] = metadata

        try:
            body, flags = await chat_completion_tools_handler(
                request, body, user, models, extra_params
            )
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        try:
            body, flags = await chat_completion_files_handler(request, body, user)
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        # If context is not empty, insert it into the messages
        if len(sources) > 0:
            context_string = ""
            for source_idx, source in enumerate(sources):
                source_id = source.get("source", {}).get("name", "")

                if "document" in source:
                    for doc_idx, doc_context in enumerate(source["document"]):
                        metadata = source.get("metadata")
                        doc_source_id = None

                        if metadata:
                            doc_source_id = metadata[doc_idx].get("source", source_id)

                        if source_id:
                            context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
                        else:
                            # If there is no source_id, then do not include the source_id tag
                            context_string += f"<source><source_context>{doc_context}</source_context></source>\n"

            context_string = context_string.strip()
            prompt = get_last_user_message(body["messages"])

            if prompt is None:
                raise Exception("No user message found")
            if (
                app.state.config.RELEVANCE_THRESHOLD == 0
                and context_string.strip() == ""
            ):
                log.debug(
                    f"With a 0 relevancy threshold for RAG, the context cannot be empty"
                )

            # Workaround for Ollama 2.0+ system prompt issue
            # TODO: replace with add_or_update_system_message
            if model["owned_by"] == "ollama":
                body["messages"] = prepend_to_first_user_message_content(
                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                    body["messages"],
                )
            else:
                body["messages"] = add_or_update_system_message(
                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                    body["messages"],
                )

        # If there are citations, add them to the data_items
        sources = [
            source for source in sources if source.get("source", {}).get("name", "")
        ]
        if len(sources) > 0:
            data_items.append({"sources": sources})

        modified_body_bytes = json.dumps(body).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]

        response = await call_next(request)
        if not isinstance(response, StreamingResponse):
            return response

        content_type = response.headers["Content-Type"]
        is_openai = "text/event-stream" in content_type
        is_ollama = "application/x-ndjson" in content_type
        if not is_openai and not is_ollama:
            return response

        def wrap_item(item):
            return f"data: {item}\n\n" if is_openai else f"{item}\n"

        async def stream_wrapper(original_generator, data_items):
            for item in data_items:
                yield wrap_item(json.dumps(item))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, data_items),
            headers=dict(response.headers),
        )

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


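stream_wrapper above is the reason the middleware returns a fresh StreamingResponse: it prepends the collected data_items (citation sources and similar client-bound payloads) to the model's own stream, framed as SSE for OpenAI-style responses or NDJSON for Ollama. The pattern in isolation, as a sketch with an illustrative name:

import json
from typing import AsyncIterator


async def prepend_items(original: AsyncIterator, data_items: list, is_openai: bool):
    # Frame extra items the way the client already expects, then pass the
    # upstream generator through untouched.
    for item in data_items:
        payload = json.dumps(item)
        yield f"data: {payload}\n\n" if is_openai else f"{payload}\n"
    async for chunk in original:
        yield chunk
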
app.add_middleware(ChatCompletionMiddleware)


class PipelineMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if not (
            request.method == "POST"
            and any(
                endpoint in request.url.path
                for endpoint in ["/ollama/api/chat", "/chat/completions"]
            )
        ):
            return await call_next(request)

        log.debug(f"request.url.path: {request.url.path}")

        # Read the original request body
        body = await request.body()
        # Decode body to string
        body_str = body.decode("utf-8")
        # Parse string to JSON
        data = json.loads(body_str) if body_str else {}

        try:
            user = get_current_user(
                request,
                get_http_authorization_cred(request.headers["Authorization"]),
            )
        except KeyError as e:
            if len(e.args) > 1:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )
            else:
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "Not authenticated"},
                )
        except HTTPException as e:
            return JSONResponse(
                status_code=e.status_code,
                content={"detail": e.detail},
            )

        await get_all_models(request)
        models = app.state.MODELS

        try:
            data = process_pipeline_inlet_filter(request, data, user, models)
        except Exception as e:
            if len(e.args) > 1:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )
            else:
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={"detail": str(e)},
                )

        modified_body_bytes = json.dumps(data).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)
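Both removed middlewares end the same way: after mutating the JSON body they cache it on request._body and rebuild Starlette's raw header list so Content-Length matches the new length; otherwise the downstream app would read a truncated or padded body. A minimal standalone sketch of that trick, assuming Starlette; RewriteBodyMiddleware and the "injected" key are hypothetical:

import json

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request


class RewriteBodyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        body = await request.body()
        data = json.loads(body.decode("utf-8")) if body else {}
        data["injected"] = True  # any mutation

        new_body = json.dumps(data).encode("utf-8")
        request._body = new_body  # cached value returned by request.body()
        # Put a corrected content-length first and drop the stale one,
        # the same trick both middlewares above rely on.
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(new_body)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]
        return await call_next(request)
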


class RedirectMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
@@ -1471,8 +812,32 @@ async def chat_completion(
    user=Depends(get_verified_user),
    bypass_filter: bool = False,
):

    try:
        return await chat_completion_handler(request, form_data, user, bypass_filter)
        model_id = form_data.get("model", None)
        if model_id not in request.app.state.MODELS:
            raise Exception("Model not found")
        model = request.app.state.MODELS[model_id]

        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            try:
                check_model_access(user, model)
            except Exception as e:
                raise e

        form_data, events = await process_chat_payload(request, form_data, user, model)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )

    try:
        response = await chat_completion_handler(
            request, form_data, user, bypass_filter
        )
        return await process_chat_response(response, events)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
@@ -1480,6 +845,7 @@ async def chat_completion(
        )


# Alias for chat_completion (Legacy)
generate_chat_completions = chat_completion
generate_chat_completion = chat_completion