mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
wip
This commit is contained in:
@@ -30,6 +30,130 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Pipeline Middleware
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
def get_sorted_filters(model_id, models):
|
||||
filters = [
|
||||
model
|
||||
for model in models.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
return sorted_filters
|
||||
|
||||
|
||||
def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key == "":
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
raise Exception(r.status_code, res["detail"])
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers={"Authorization": f"Bearer {key}"},
|
||||
json={
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
},
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return Exception(r.status_code, res)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Pipelines Endpoints
|
||||
@@ -39,7 +163,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/api/pipelines/list")
|
||||
@router.get("/list")
|
||||
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
||||
responses = await get_all_models_responses(request)
|
||||
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
|
||||
@@ -61,7 +185,7 @@ async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
||||
}
|
||||
|
||||
|
||||
@router.post("/api/pipelines/upload")
|
||||
@router.post("/upload")
|
||||
async def upload_pipeline(
|
||||
request: Request,
|
||||
urlIdx: int = Form(...),
|
||||
@@ -131,7 +255,7 @@ class AddPipelineForm(BaseModel):
|
||||
urlIdx: int
|
||||
|
||||
|
||||
@router.post("/api/pipelines/add")
|
||||
@router.post("/add")
|
||||
async def add_pipeline(
|
||||
request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
|
||||
):
|
||||
@@ -176,7 +300,7 @@ class DeletePipelineForm(BaseModel):
|
||||
urlIdx: int
|
||||
|
||||
|
||||
@router.delete("/api/pipelines/delete")
|
||||
@router.delete("/delete")
|
||||
async def delete_pipeline(
|
||||
request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
|
||||
):
|
||||
@@ -216,7 +340,7 @@ async def delete_pipeline(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/pipelines")
|
||||
@router.get("/")
|
||||
async def get_pipelines(
|
||||
request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
|
||||
):
|
||||
@@ -250,7 +374,7 @@ async def get_pipelines(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/pipelines/{pipeline_id}/valves")
|
||||
@router.get("/{pipeline_id}/valves")
|
||||
async def get_pipeline_valves(
|
||||
request: Request,
|
||||
urlIdx: Optional[int],
|
||||
@@ -289,7 +413,7 @@ async def get_pipeline_valves(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/pipelines/{pipeline_id}/valves/spec")
|
||||
@router.get("/{pipeline_id}/valves/spec")
|
||||
async def get_pipeline_valves_spec(
|
||||
request: Request,
|
||||
urlIdx: Optional[int],
|
||||
@@ -329,7 +453,7 @@ async def get_pipeline_valves_spec(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/pipelines/{pipeline_id}/valves/update")
|
||||
@router.post("/{pipeline_id}/valves/update")
|
||||
async def update_pipeline_valves(
|
||||
request: Request,
|
||||
urlIdx: Optional[int],
|
||||
|
||||
Reference in New Issue
Block a user