Files
Iliyan Angelov c67067a2a4 Mail
2025-09-14 23:24:25 +03:00

673 lines
21 KiB
Python

from functools import partial
from django.core.cache import cache, InvalidCacheBackendError
from django.core.exceptions import ImproperlyConfigured
from django.test import RequestFactory, TestCase
from django.test.utils import override_settings
from django.utils.decorators import method_decorator
from django.views.generic import View
from django_ratelimit.decorators import ratelimit
from django_ratelimit.exceptions import Ratelimited
from django_ratelimit.core import (get_usage, is_ratelimited,
_split_rate, _get_ip)
# Module-level request factory shared by the tests below to build requests.
rf = RequestFactory()
class MockUser:
    """Minimal stand-in for a user object attached to test requests.

    Exposes only the two attributes set here: a fixed ``pk`` of 1 and an
    ``is_authenticated`` flag.
    """

    def __init__(self, authenticated=False):
        """Store the fixed primary key and the requested auth state."""
        self.pk = 1
        self.is_authenticated = authenticated
class RateParsingTests(TestCase):
    """Checks how ``_split_rate`` parses rate strings."""

    def test_simple(self):
        """Each rate string parses to the expected ``(count, seconds)`` pair."""
        cases = (
            ('100/s', (100, 1)),
            ('100/10s', (100, 10)),
            ('100/10', (100, 10)),
            ('100/m', (100, 60)),
            ('400/10m', (400, 600)),
            ('1000/h', (1000, 3600)),
            ('800/d', (800, 24 * 60 * 60)),
        )
        for rate, expected in cases:
            # subTest names the failing rate string instead of stopping at
            # an anonymous AssertionError.
            with self.subTest(rate=rate):
                self.assertEqual(expected, _split_rate(rate))
def callable_rate(group, request):
    """Rate callable referenced by dotted path in the tests.

    Returns ``None`` when ``request.user.is_authenticated`` is truthy and
    ``(0, 1)`` otherwise. ``group`` is accepted but not used.
    """
    if request.user.is_authenticated:
        return None
    return (0, 1)
def mykey(group, request):
    """Key callable used by the tests: the reversed ``REMOTE_ADDR`` string.

    ``group`` is accepted but not used. Raises ``KeyError`` if the request's
    ``META`` has no ``REMOTE_ADDR`` entry.
    """
    return request.META['REMOTE_ADDR'][::-1]
class CustomRatelimitedException(Exception):
    """Exception type the tests plug in via ``RATELIMIT_EXCEPTION_CLASS``."""
class RatelimitTests(TestCase):
    """Tests for the ``@ratelimit`` decorator applied to function views."""

    def setUp(self):
        # Start each test from an empty cache so counts do not carry over.
        cache.clear()

    def test_no_key(self):
        """Calling a view decorated without ``key`` raises ImproperlyConfigured."""
        @ratelimit(rate='1/m')
        def view(request):
            return True

        req = rf.get('/')
        with self.assertRaises(ImproperlyConfigured):
            view(req)

    def test_ip(self):
        """With ``key='ip'`` and ``1/m``, the second request is flagged."""
        @ratelimit(key='ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/')), 'First request works.'
        assert view(rf.get('/')), 'Second request is limited'

    def test_block(self):
        """Without ``block=False`` the second request raises ``Ratelimited``."""
        @ratelimit(key='ip', rate='1/m')
        def blocked(request):
            return request.limited

        assert not blocked(rf.get('/')), 'First request works.'
        # The message belongs to assertRaises; the original attached it to a
        # discarded tuple, so it could never be reported.
        with self.assertRaises(Ratelimited, msg='Second request is blocked.'):
            blocked(rf.get('/'))

    def test_ratelimit_custom_string_exception_class(self):
        """``RATELIMIT_EXCEPTION_CLASS`` may be given as a dotted path string."""
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return request.limited

        with self.settings(
            RATELIMIT_EXCEPTION_CLASS=(
                "django_ratelimit.tests.CustomRatelimitedException"
            )
        ):
            req = rf.get("")
            assert not view(req)
            with self.assertRaises(CustomRatelimitedException):
                view(req)

    def test_ratelimit_custom_exception_class(self):
        """``RATELIMIT_EXCEPTION_CLASS`` may be given as the class itself."""
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return request.limited

        with self.settings(
            RATELIMIT_EXCEPTION_CLASS=CustomRatelimitedException
        ):
            req = rf.get("")
            assert not view(req)
            with self.assertRaises(CustomRatelimitedException):
                view(req)

    def test_method(self):
        """A ``method='POST'`` limit flags repeated POSTs but not a GET."""
        @ratelimit(key='ip', method='POST', rate='1/m', group='a', block=False)
        def limit_post(request):
            return request.limited

        assert not limit_post(rf.post('/')), 'Do not limit first POST.'
        assert limit_post(rf.post('/')), 'Limit second POST.'
        assert not limit_post(rf.get('/')), 'Do not limit GET.'

    def test_unsafe_methods(self):
        """``ratelimit.UNSAFE`` at ``0/m`` flags DELETE/POST/PUT/PATCH only."""
        @ratelimit(key='ip', method=ratelimit.UNSAFE, rate='0/m', block=False)
        def limit_unsafe(request):
            return request.limited

        assert not limit_unsafe(rf.get('/'))
        assert not limit_unsafe(rf.head('/'))
        assert not limit_unsafe(rf.options('/'))
        assert limit_unsafe(rf.delete('/'))
        assert limit_unsafe(rf.post('/'))
        assert limit_unsafe(rf.put('/'))
        assert limit_unsafe(rf.patch('/'))

    def test_key_get(self):
        """``key='get:foo'`` counts each value of the ``foo`` GET parameter apart."""
        @ratelimit(key='get:foo', rate='1/m', method='GET', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/', {'foo': 'a'}))
        assert view(rf.get('/', {'foo': 'a'}))
        assert not view(rf.get('/', {'foo': 'b'}))
        assert view(rf.get('/', {'foo': 'b'}))

    def test_key_post(self):
        """``key='post:foo'`` counts each value of the ``foo`` POST field apart."""
        @ratelimit(key='post:foo', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.post('/', {'foo': 'a'}))
        assert view(rf.post('/', {'foo': 'a'}))
        assert not view(rf.post('/', {'foo': 'b'}))
        assert view(rf.post('/', {'foo': 'b'}))

    def test_key_header(self):
        """``header:`` keys work for a present header and for a missing one."""
        def _req():
            req = rf.post('/')
            req.META['HTTP_X_REAL_IP'] = '1.2.3.4'
            return req

        @ratelimit(key='header:x-real-ip', rate='1/m', block=False)
        @ratelimit(key='header:x-missing-header', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(_req())
        assert view(_req())

    def test_rate(self):
        """A ``2/m`` rate lets two requests through and flags the third."""
        @ratelimit(key='ip', rate='2/m', block=False)
        def twice(request):
            return request.limited

        assert not twice(rf.post('/')), 'First request is not limited.'
        assert not twice(rf.post('/')), 'Second request is not limited.'
        assert twice(rf.post('/')), 'Third request is limited.'

    def test_zero_rate(self):
        """A ``0/m`` rate flags even the first request."""
        @ratelimit(key='ip', rate='0/m', block=False)
        def never(request):
            return request.limited

        assert never(rf.post('/'))

    def test_none_rate(self):
        """``rate=None`` never flags a request, however many are made."""
        @ratelimit(key='ip', rate=None, block=False)
        def always(request):
            return request.limited

        # Same seven consecutive requests as before, written as a loop.
        for _ in range(7):
            assert not always(rf.post('/'))

    def test_callable_rate(self):
        """A callable rate may return different tuples per request."""
        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        def get_rate(group, request):
            if request.user.is_authenticated:
                return (2, 60)
            return (1, 60)

        @ratelimit(key='user_or_ip', rate=get_rate, block=False)
        def view(request):
            return request.limited

        assert not view(_req(auth=False))
        assert view(_req(auth=False))
        assert not view(_req(auth=True))
        assert not view(_req(auth=True))
        assert view(_req(auth=True))

    def test_callable_rate_none(self):
        """A callable rate returning ``None`` exempts that request."""
        def _req(never_limit=False):
            req = rf.post('/')
            req.never_limit = never_limit
            return req

        # A named function instead of a lambda bound to a name (PEP 8 E731).
        def get_rate(group, request):
            return None if request.never_limit else '1/m'

        @ratelimit(key='ip', rate=get_rate, block=False)
        def view(request):
            return request.limited

        assert not view(_req())
        assert view(_req())
        assert not view(_req(never_limit=True))
        assert not view(_req(never_limit=True))

    def test_callable_rate_zero(self):
        """A callable rate returning ``'0/m'`` flags the request immediately."""
        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        def get_rate(group, request):
            if request.user.is_authenticated:
                return '1/m'
            return '0/m'

        @ratelimit(key='ip', rate=get_rate, block=False)
        def view(request):
            return request.limited

        assert view(_req(auth=False))
        assert not view(_req(auth=True))
        assert view(_req(auth=True))

    def test_callable_rate_import(self):
        """The rate callable may be given as a dotted import path."""
        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        @ratelimit(key='user_or_ip',
                   rate='django_ratelimit.tests.callable_rate',
                   block=False)
        def view(request):
            return request.limited

        assert view(_req(auth=False))
        assert not view(_req(auth=True))

    def test_user_or_ip(self):
        """``key='user_or_ip'`` counts anonymous and authenticated requests apart."""
        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        @ratelimit(key='user_or_ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(_req(auth=False))
        assert view(_req(auth=False))

        assert not view(_req(auth=True))
        assert view(_req(auth=True))

    def test_callable_key_path(self):
        """The key callable may be given as a dotted import path."""
        @ratelimit(key='django_ratelimit.tests.mykey', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_callable_key(self):
        """The key may be given as a callable object."""
        @ratelimit(key=mykey, rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_stacked_decorator(self):
        """Allow @ratelimit to be stacked."""
        # Put the shorter one first and make sure the second one doesn't
        # reset request.limited back to False.
        @ratelimit(rate='1/m', block=False, key=lambda x, y: 'min')
        @ratelimit(rate='10/d', block=False, key=lambda x, y: 'day')
        def view(request):
            return request.limited

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_stacked_methods(self):
        """Different methods should result in different counts."""
        @ratelimit(rate='1/m', key='ip', method='GET', block=False)
        @ratelimit(rate='1/m', key='ip', method='POST', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/'))
        assert not view(rf.post('/'))
        assert view(rf.get('/'))
        assert view(rf.post('/'))

    def test_sorted_methods(self):
        """Order of the methods shouldn't matter."""
        @ratelimit(rate='1/m', key='ip', method=['GET', 'POST'],
                   group='a', block=False)
        def get_post(request):
            return request.limited

        @ratelimit(rate='1/m', key='ip', method=['POST', 'GET'],
                   group='a', block=False)
        def post_get(request):
            return request.limited

        assert not get_post(rf.get('/'))
        assert post_get(rf.get('/'))

    def test_ratelimit_full_mask_v4(self):
        """With ``RATELIMIT_IPV4_MASK=32`` adjacent IPv4 addresses count apart."""
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV4_MASK=32):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.1.1'
            assert not view(req)
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.1.2'
            assert not view(req)

    def test_ratelimit_full_mask_v6(self):
        """With ``RATELIMIT_IPV6_MASK=128`` adjacent IPv6 addresses count apart."""
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV6_MASK=128):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1000'
            assert not view(req)
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1001'
            assert not view(req)

    def test_ratelimit_mask_v4(self):
        """With ``RATELIMIT_IPV4_MASK=16`` addresses sharing 10.1.x.x share a count."""
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV4_MASK=16):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.1.1'
            assert not view(req)
            assert view(req)

            # Same first two octets: already over the limit.
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.0.1'
            assert view(req)

            # Different first two octets: fresh count.
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '192.168.1.1'
            assert not view(req)

    def test_ratelimit_mask_v6(self):
        """With ``RATELIMIT_IPV6_MASK=64`` addresses in one /64 share a count."""
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV6_MASK=64):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1000'
            assert not view(req)
            assert view(req)

            # Same leading 64 bits: already over the limit.
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1001'
            assert view(req)

            # Different leading 64 bits: fresh count.
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db9::1000'
            assert not view(req)
class FunctionsTests(TestCase):
    """Tests for calling ``is_ratelimited`` and ``get_usage`` directly."""

    def setUp(self):
        # Start each test from an empty cache so counts do not carry over.
        cache.clear()

    def test_is_ratelimited(self):
        """With ``increment=False`` repeated checks are never limited."""
        not_increment = partial(is_ratelimited, increment=False, rate='1/m',
                                method=is_ratelimited.ALL, key='ip', group='a')

        # Does not increment. Count still 0. Does not rate limit
        # because 0 < 1.
        assert not not_increment(rf.get('/'))

        # Still does not increment, so the count stays at 0 and the
        # request is again not limited.
        assert not not_increment(rf.get('/'))

    def test_is_ratelimited_increment(self):
        """With ``increment=True`` the second check at ``1/m`` is limited."""
        do_increment = partial(is_ratelimited, increment=True, rate='1/m',
                               method=is_ratelimited.ALL, key='ip', group='a')

        # Increments. Does not rate limit because 0 < 1. Count now 1.
        assert not do_increment(rf.get('/'))

        # Count = 2, 2 > 1.
        assert do_increment(rf.get('/'))

    def test_get_usage(self):
        """A first non-incrementing ``get_usage`` call reports an unused limit."""
        _get_usage = partial(get_usage, method=get_usage.ALL, key='ip',
                             rate='1/m', group='a')
        usage = _get_usage(rf.get('/'))
        self.assertEqual(usage['count'], 0)
        self.assertEqual(usage['limit'], 1)
        self.assertLessEqual(usage['time_left'], 60)
        self.assertFalse(usage['should_limit'])

    def test_get_usage_increment(self):
        """Two incrementing calls at ``1/m`` report count 2 and should_limit."""
        _get_usage = partial(get_usage, method=get_usage.ALL, key='ip',
                             rate='1/m', group='a', increment=True)
        _get_usage(rf.get('/'))
        usage = _get_usage(rf.get('/'))
        self.assertEqual(usage['count'], 2)
        self.assertEqual(usage['limit'], 1)
        self.assertLessEqual(usage['time_left'], 60)
        self.assertTrue(usage['should_limit'])

    def test_not_increment_after_increment(self):
        """A non-incrementing call after two increments still reports count 2."""
        _get_usage = partial(get_usage, method=get_usage.ALL, key='ip',
                             rate='1/m', group='a')
        _get_usage(rf.get('/'), increment=True)
        _get_usage(rf.get('/'), increment=True)
        usage = _get_usage(rf.get('/'))
        self.assertEqual(usage['count'], 2)
        self.assertEqual(usage['limit'], 1)
        self.assertLessEqual(usage['time_left'], 60)
        self.assertTrue(usage['should_limit'])

    def test_get_usage_called_without_group_or_fn(self):
        """``get_usage`` given only a request and key raises ImproperlyConfigured."""
        with self.assertRaises(ImproperlyConfigured):
            get_usage(rf.get('/'), key='ip')
class RatelimitCBVTests(TestCase):
    """Tests for ``@ratelimit`` applied to class-based views."""

    def setUp(self):
        # Start each test from an empty cache so counts do not carry over.
        cache.clear()

    def test_method_decorator(self):
        """``method_decorator`` on a handler method limits the second request."""
        class TestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m', block=False))
            def post(self, request):
                return request.limited

        view = TestView.as_view()

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_class_decorator(self):
        """``method_decorator(..., name='get')`` on the class limits ``get``."""
        @method_decorator(ratelimit(key='ip', rate='1/m', block=False),
                          name='get')
        class TestView(View):
            def get(self, request):
                return request.limited

        view = TestView.as_view()

        assert not view(rf.get('/'))
        assert view(rf.get('/'))

    def test_wrap_view(self):
        """Wrapping the result of ``as_view()`` directly also limits requests."""
        class TestView(View):
            def get(self, request):
                return request.limited

        view = TestView.as_view()
        wrapped = ratelimit(key='ip', rate='1/m', block=False)(view)

        assert not wrapped(rf.get('/'))
        assert wrapped(rf.get('/'))

    def test_methods_counted_separately(self):
        """GET and POST handlers on one view keep separate counts."""
        class TestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='GET', block=False))
            def get(self, request):
                return request.limited

            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='POST', block=False))
            def post(self, request):
                return request.limited

        view = TestView.as_view()

        assert not view(rf.get('/'))
        assert view(rf.get('/'))
        assert not view(rf.post('/'))

    def test_views_counted_separately(self):
        """Two different view classes with identical limits keep separate counts."""
        class TestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='GET', block=False))
            def get(self, request):
                return request.limited

        class AnotherTestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='GET', block=False))
            def get(self, request):
                return request.limited

        test_view = TestView.as_view()
        another_view = AnotherTestView.as_view()

        assert not test_view(rf.get('/'))
        assert test_view(rf.get('/'))
        assert not another_view(rf.get('/'))
class CacheFailTests(TestCase):
    """Tests run with ``RATELIMIT_USE_CACHE`` pointed at special cache aliases.

    The aliases named here ('connection-errors', 'connection-errors-redis',
    'instant-expiration') are not defined in this file; presumably they are
    configured in the test settings -- verify there.
    """

    @override_settings(RATELIMIT_USE_CACHE='fake-cache')
    def test_bad_cache(self):
        """The 'fake-cache' alias raises ``InvalidCacheBackendError``."""
        @ratelimit(key='ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        with self.assertRaises(InvalidCacheBackendError):
            view(rf.post('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_limit_on_cache_connection_error(self):
        """With the 'connection-errors' alias the first request is flagged."""
        @ratelimit(key='ip', rate='10/m', block=False)
        def view(request):
            return request.limited

        assert view(rf.post('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors',
                       RATELIMIT_FAIL_OPEN=True)
    def test_fail_open_setting(self):
        """Adding ``RATELIMIT_FAIL_OPEN=True`` leaves requests unflagged."""
        @ratelimit(key='ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/'))
        assert not view(rf.get('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_is_ratelimited_cache_connection_error_without_increment(self):
        """``is_ratelimited(increment=False)`` is not limited on this alias."""
        def not_increment(request):
            return is_ratelimited(request, increment=False,
                                  method=is_ratelimited.ALL, key='ip',
                                  rate='1/m', group='a')

        assert not not_increment(rf.get('/'))
        assert not not_increment(rf.get('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_is_ratelimited_cache_connection_error_with_increment(self):
        """``is_ratelimited(increment=True)`` is limited on this alias."""
        def do_increment(request):
            return is_ratelimited(request, increment=True,
                                  method=is_ratelimited.ALL, key='ip',
                                  rate='1/m', group='a')

        assert do_increment(rf.get('/'))
        assert do_increment(rf.get('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors-redis')
    def test_is_ratelimited_cache_connection_error_with_increment_redis(self):
        """Same check as above against the 'connection-errors-redis' alias."""
        def do_increment(request):
            return is_ratelimited(request, increment=True,
                                  method=is_ratelimited.ALL, key='ip',
                                  rate='1/m', group='a')

        assert do_increment(rf.get('/'))
        assert do_increment(rf.get('/'))

    @override_settings(RATELIMIT_USE_CACHE='instant-expiration')
    def test_cache_timeout(self):
        """With 'instant-expiration', two calls to a ``1/m`` view both succeed."""
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return True

        assert view(rf.get('/'))
        assert view(rf.get('/'))
def my_ip(req):
    """IP-lookup callable used by the tests: returns ``req.META['MY_THING']``.

    Raises ``KeyError`` if ``MY_THING`` is not present in ``req.META``.
    """
    return req.META['MY_THING']
class IpMetaTests(TestCase):
    """Tests for ``_get_ip`` under different ``RATELIMIT_IP_META_KEY`` values."""

    def test_default(self):
        """Without the setting, ``_get_ip`` returns ``REMOTE_ADDR``."""
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        assert '1.2.3.4' == _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY='fake')
    def test_bad_config(self):
        """A META key absent from the request raises ImproperlyConfigured."""
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        with self.assertRaises(ImproperlyConfigured):
            _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY='HTTP_X_CLIENT_IP')
    def test_alternate_header(self):
        """A META key naming another header takes the value from that header."""
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        req.META['HTTP_X_CLIENT_IP'] = '5.6.7.8'
        assert '5.6.7.8' == _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY='django_ratelimit.tests.my_ip')
    def test_path_to_ip_key_callable(self):
        """The setting may be a dotted path to a callable taking the request."""
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        req.META['MY_THING'] = '5.6.7.8'
        assert '5.6.7.8' == _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY=my_ip)
    def test_callable_ip_key(self):
        """The setting may be the callable object itself."""
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        req.META['MY_THING'] = '5.6.7.8'
        assert '5.6.7.8' == _get_ip(req)

    def test_empty_ip(self):
        """An empty ``REMOTE_ADDR`` raises ImproperlyConfigured."""
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = ''
        with self.assertRaises(ImproperlyConfigured):
            _get_ip(req)