From a8ca6884566ad1ccef428aa15db2b9ef2e1c07e8 Mon Sep 17 00:00:00 2001 From: Sonny Bakker Date: Sat, 17 Oct 2020 21:24:44 +0200 Subject: [PATCH] Apply ratelimitting to search endpoint --- src/newsreader/conf/base.py | 4 ++ src/newsreader/core/throttling.py | 20 ++++++++++ src/newsreader/news/collection/endpoints.py | 4 ++ .../tests/endpoints/rule/list/tests.py | 40 +++++++++++++------ 4 files changed, 55 insertions(+), 13 deletions(-) create mode 100644 src/newsreader/core/throttling.py diff --git a/src/newsreader/conf/base.py b/src/newsreader/conf/base.py index d41f352..15af28f 100644 --- a/src/newsreader/conf/base.py +++ b/src/newsreader/conf/base.py @@ -228,6 +228,10 @@ REST_FRAMEWORK = { "newsreader.accounts.permissions.IsOwner", ), "DEFAULT_RENDERER_CLASSES": ("rest_framework.renderers.JSONRenderer",), + "DEFAULT_THROTTLE_RATES": { + "burst_search": "100/min", + "sustained_search": "2000/day", + }, } SWAGGER_SETTINGS = { diff --git a/src/newsreader/core/throttling.py b/src/newsreader/core/throttling.py new file mode 100644 index 0000000..f333d95 --- /dev/null +++ b/src/newsreader/core/throttling.py @@ -0,0 +1,20 @@ +from rest_framework.throttling import UserRateThrottle + + +class SearchThrottle(UserRateThrottle): + """ + Only applies throttling to requests with the search param + """ + + def allow_request(self, request, view): + if not "search" in request.GET.keys(): + return True + return super().allow_request(request, view) + + +class BurstSearchThrottle(SearchThrottle): + scope = "burst_search" + + +class SustainedSearchThrottle(SearchThrottle): + scope = "sustained_search" diff --git a/src/newsreader/news/collection/endpoints.py b/src/newsreader/news/collection/endpoints.py index 7b43465..2ac4963 100644 --- a/src/newsreader/news/collection/endpoints.py +++ b/src/newsreader/news/collection/endpoints.py @@ -8,6 +8,7 @@ from rest_framework.generics import ( from rest_framework.response import Response from newsreader.core.pagination import LargeResultSetPagination, ResultSetPagination +from newsreader.core.throttling import BurstSearchThrottle, SustainedSearchThrottle from newsreader.news.collection.models import CollectionRule from newsreader.news.collection.serializers import RuleSerializer from newsreader.news.core.filters import ReadFilter @@ -19,9 +20,12 @@ class ListRuleView(ListAPIView): queryset = CollectionRule.objects.all() serializer_class = RuleSerializer pagination_class = ResultSetPagination + filter_backends = [filters.SearchFilter] search_fields = ["name", "screen_name", "url"] + throttle_classes = [BurstSearchThrottle, SustainedSearchThrottle] + def get_queryset(self): user = self.request.user return self.queryset.filter(user=user).order_by("name", "screen_name") diff --git a/src/newsreader/news/collection/tests/endpoints/rule/list/tests.py b/src/newsreader/news/collection/tests/endpoints/rule/list/tests.py index e5ff2fb..d474541 100644 --- a/src/newsreader/news/collection/tests/endpoints/rule/list/tests.py +++ b/src/newsreader/news/collection/tests/endpoints/rule/list/tests.py @@ -1,14 +1,17 @@ import json +import time -from datetime import date, datetime, time -from unittest import skip +from datetime import datetime from urllib.parse import urlencode +from django.core.cache import cache from django.test import TestCase from django.urls import reverse import pytz +from freezegun import freeze_time + from newsreader.accounts.tests.factories import UserFactory from newsreader.news.collection.tests.factories import ( FeedFactory, @@ -210,9 +213,26 @@ class RuleListViewSearchTestCase(TestCase): self.assertEqual(response_data["results"][1]["id"], rules["foo"].pk) self.assertEqual(response_data["results"][2]["id"], rules["FooBar"].pk) - @skip("TODO") + @freeze_time("2020-10-30 14:00") def test_ratelimitting(self): - pass + # Trigger ratelimit + cache.set( + f"throttle_burst_search_{self.user.pk}", [time.time() for i in range(100)] + ) + + params = urlencode({"search": "foo"}) + url = reverse("api:news:collection:rules-list") + + response = self.client.get(f"{url}?{params}") + response_data = response.json() + + self.assertEqual(response.status_code, 429) + + message = response_data["detail"] + + self.assertIn("Request was throttled", message) + + cache.delete(f"throttle_burst_search_{self.user.pk}") class NestedRuleListViewTestCase(TestCase): @@ -357,23 +377,17 @@ class NestedRuleListViewTestCase(TestCase): FeedPostFactory( title="I'm the first post", rule=rule, - publication_date=datetime.combine( - date(2019, 5, 20), time(hour=16, minute=7, second=37), pytz.utc - ), + publication_date=datetime(2019, 5, 20, 16, 7, 37, tzinfo=pytz.utc), ), FeedPostFactory( title="I'm the second post", rule=rule, - publication_date=datetime.combine( - date(2019, 7, 20), time(hour=18, minute=7, second=37), pytz.utc - ), + publication_date=datetime(2019, 7, 20, 18, 7, 37, tzinfo=pytz.utc), ), FeedPostFactory( title="I'm the third post", rule=rule, - publication_date=datetime.combine( - date(2019, 7, 20), time(hour=16, minute=7, second=37), pytz.utc - ), + publication_date=datetime(2019, 7, 20, 16, 7, 37, tzinfo=pytz.utc), ), ]