chore: format

This commit is contained in:
Timothy J. Baek
2024-10-20 18:38:06 -07:00
parent 768b7e139c
commit 9936583477
6 changed files with 85 additions and 72 deletions

View File

@@ -208,8 +208,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
app.state.MODELS = {}
##################################
#
# ChatCompletion Middleware
@@ -223,14 +221,14 @@ def get_task_model_id(default_model_id):
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
@@ -367,7 +365,7 @@ async def get_content_from_response(response) -> Optional[str]:
async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict
body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
@@ -681,15 +679,15 @@ def get_sorted_filters(model_id):
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
return sorted_filters
@@ -875,8 +873,8 @@ async def update_embedding_function(request: Request, call_next):
@app.middleware("http")
async def inspect_websocket(request: Request, call_next):
if (
"/ws/socket.io" in request.url.path
and request.query_params.get("transport") == "websocket"
"/ws/socket.io" in request.url.path
and request.query_params.get("transport") == "websocket"
):
upgrade = (request.headers.get("Upgrade") or "").lower()
connection = (request.headers.get("Connection") or "").lower().split(",")
@@ -945,8 +943,8 @@ async def get_all_models():
if custom_model.base_model_id is None:
for model in models:
if (
custom_model.id == model["id"]
or custom_model.id == model["id"].split(":")[0]
custom_model.id == model["id"]
or custom_model.id == model["id"].split(":")[0]
):
model["name"] = custom_model.name
model["info"] = custom_model.model_dump()
@@ -963,8 +961,8 @@ async def get_all_models():
for model in models:
if (
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
):
owned_by = model["owned_by"]
if "pipe" in model:
@@ -1840,7 +1838,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
@app.post("/api/pipelines/upload")
async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file
@@ -2017,9 +2015,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
r = None
try:
@@ -2055,9 +2053,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
r = None
try:
@@ -2092,10 +2090,10 @@ async def get_pipeline_valves_spec(
@app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves(
urlIdx: Optional[int],
pipeline_id: str,
form_data: dict,
user=Depends(get_admin_user),
urlIdx: Optional[int],
pipeline_id: str,
form_data: dict,
user=Depends(get_admin_user),
):
r = None
try:
@@ -2219,7 +2217,7 @@ class ModelFilterConfigForm(BaseModel):
@app.post("/api/config/model/filter")
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models
@@ -2274,7 +2272,7 @@ async def get_app_latest_release_version():
timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response:
response.raise_for_status()
data = await response.json()