From a5a449680086da76d3f435e2713d3e2abb00ec22 Mon Sep 17 00:00:00 2001 From: vamsi Date: Wed, 14 Aug 2024 17:41:40 +0530 Subject: [PATCH] fix: adding throttling at base api view for external apis --- apiserver/plane/api/views/base.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/apiserver/plane/api/views/base.py b/apiserver/plane/api/views/base.py index fee508a30c..a3241eaf3b 100644 --- a/apiserver/plane/api/views/base.py +++ b/apiserver/plane/api/views/base.py @@ -7,6 +7,7 @@ from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.db import IntegrityError from django.urls import resolve from django.utils import timezone +from plane.db.models.api import APIToken from rest_framework import status from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -16,7 +17,7 @@ from rest_framework.views import APIView # Module imports from plane.api.middleware.api_authentication import APIKeyAuthentication -from plane.api.rate_limit import ApiKeyRateThrottle +from plane.api.rate_limit import ApiKeyRateThrottle, ServiceTokenRateThrottle from plane.utils.exception_logger import log_exception from plane.utils.paginator import BasePaginator @@ -44,15 +45,29 @@ class BaseAPIView(TimezoneMixin, APIView, BasePaginator): IsAuthenticated, ] - throttle_classes = [ - ApiKeyRateThrottle, - ] - def filter_queryset(self, queryset): for backend in list(self.filter_backends): queryset = backend().filter_queryset(self.request, queryset, self) return queryset + def get_throttles(self): + throttle_classes = [] + api_key = self.request.headers.get("X-Api-Key") + + if api_key: + service_token = APIToken.objects.filter( + token=api_key, + is_service=True, + ).first() + + if service_token: + throttle_classes.append(ServiceTokenRateThrottle()) + return throttle_classes + + throttle_classes.append(ApiKeyRateThrottle()) + + return throttle_classes + def handle_exception(self, exc): """ Handle any exception that occurs, by returning an appropriate response, @@ -152,4 +167,4 @@ class BaseAPIView(TimezoneMixin, APIView, BasePaginator): for expand in self.request.GET.get("expand", "").split(",") if expand ] - return expand if expand else None + return expand if expand else None \ No newline at end of file