feat: pipelines filter outlet

This commit is contained in:
Timothy J. Baek
2024-05-30 02:04:29 -07:00
parent d9ceb31674
commit ef8d84296e
3 changed files with 156 additions and 4 deletions

View File

@@ -141,7 +141,8 @@ class RAGMiddleware(BaseHTTPMiddleware):
return_citations = False
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
"/ollama/api/chat" in request.url.path
or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}")
@@ -229,7 +230,8 @@ app.add_middleware(RAGMiddleware)
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
"/ollama/api/chat" in request.url.path
or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}")
@@ -308,6 +310,9 @@ class PipelineMiddleware(BaseHTTPMiddleware):
else:
pass
if "chat_id" in data:
del data["chat_id"]
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
@@ -464,6 +469,69 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models}
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
return data
@app.get("/api/pipelines/list")
async def get_pipelines_list(user=Depends(get_admin_user)):
responses = await get_openai_models(raw=True)