feat: direct connections integration

This commit is contained in:
Timothy Jaeryang Baek
2025-02-12 22:56:33 -08:00
parent 304ce2a14d
commit c83e68282d
6 changed files with 387 additions and 94 deletions

View File

@@ -900,20 +900,30 @@ async def chat_completion(
if not request.app.state.MODELS:
await get_all_models(request)
model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)
try:
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
try:
if not model_item.get("direct", False):
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
else:
model = model_item
model_info = None
request.state.direct = True
request.state.model = model
metadata = {
"user_id": user.id,
@@ -925,6 +935,7 @@ async def chat_completion(
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),
"model": model_info,
"direct": model_item.get("direct", False),
**(
{"function_calling": "native"}
if form_data.get("params", {}).get("function_calling") == "native"
@@ -936,6 +947,7 @@ async def chat_completion(
else {}
),
}
request.state.metadata = metadata
form_data["metadata"] = metadata
form_data, metadata, events = await process_chat_payload(
@@ -943,6 +955,7 @@ async def chat_completion(
)
except Exception as e:
log.debug(f"Error processing chat payload: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
@@ -971,6 +984,12 @@ async def chat_completed(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
try:
model_item = form_data.pop("model_item", {})
if model_item.get("direct", False):
request.state.direct = True
request.state.model = model_item
return await chat_completed_handler(request, form_data, user)
except Exception as e:
raise HTTPException(
@@ -984,6 +1003,12 @@ async def chat_action(
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
):
try:
model_item = form_data.pop("model_item", {})
if model_item.get("direct", False):
request.state.direct = True
request.state.model = model_item
return await chat_action_handler(request, action_id, form_data, user)
except Exception as e:
raise HTTPException(