refac: decouple api key restrictions from get user

This commit is contained in:
Timothy Jaeryang Baek
2025-11-13 19:52:04 -05:00
parent e2ff2ae252
commit b160eef7eb
2 changed files with 41 additions and 23 deletions

View File

@@ -1218,6 +1218,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
app.state.MODELS = {}
# Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:
app.add_middleware(CompressMiddleware)
class RedirectMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
@@ -1259,14 +1263,47 @@ class RedirectMiddleware(BaseHTTPMiddleware):
return response
# Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:
app.add_middleware(CompressMiddleware)
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
auth_header = request.headers.get("Authorization")
# Only apply restrictions if an sk- API key is used
if auth_header and auth_header.startswith("sk-"):
# Check if restrictions are enabled
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
allowed_paths = [
path.strip()
for path in str(
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
).split(",")
if path.strip()
]
request_path = request.url.path
# Match exact path or prefix path
is_allowed = any(
request_path == allowed or request_path.startswith(allowed + "/")
for allowed in allowed_paths
)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="API key not allowed to access this endpoint.",
)
response = await call_next(request)
return response
app.add_middleware(APIKeyRestrictionMiddleware)
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
response = await call_next(request)