Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def update_service_config(connection_manager, res):
blocked_uids=res.get("blockedUserIds", []),
bypassed_ips=res.get("allowedIPAddresses", []),
received_any_stats=res.get("receivedAnyStats", True),
excluded_uids_from_rate_limiting=res.get("excludedUserIdsFromRateLimiting", []),
)

# Handle outbound request blocking configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,48 @@ def test_update_service_config_block_new_outgoing_requests_only():
assert connection_manager.conf.outbound_domains == {
"existing.com": "allow"
} # Not changed


def test_update_service_config_excluded_user_ids_from_rate_limiting():
"""Test that excludedUserIdsFromRateLimiting is correctly applied to config"""
connection_manager = MagicMock()
connection_manager.conf = ServiceConfig(
endpoints=[],
last_updated_at=0,
blocked_uids=set(),
bypassed_ips=[],
received_any_stats=False,
)
connection_manager.block = False

res = {
"success": True,
"excludedUserIdsFromRateLimiting": ["user1", "user2"],
}

update_service_config(connection_manager, res)

assert connection_manager.conf.excluded_uids_from_rate_limiting == {
"user1",
"user2",
}


def test_update_service_config_excluded_user_ids_defaults_to_empty():
"""Test that excluded_uids_from_rate_limiting defaults to empty set when field is absent"""
connection_manager = MagicMock()
connection_manager.conf = ServiceConfig(
endpoints=[],
last_updated_at=0,
blocked_uids=set(),
bypassed_ips=[],
received_any_stats=False,
excluded_uids_from_rate_limiting=["user1"],
)
connection_manager.block = False

res = {"success": True}

update_service_config(connection_manager, res)

assert connection_manager.conf.excluded_uids_from_rate_limiting == set()
12 changes: 11 additions & 1 deletion aikido_zen/background_process/service_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ def __init__(
blocked_uids,
bypassed_ips,
received_any_stats: bool,
excluded_uids_from_rate_limiting=None,
):
# Init the class using update function :
self.update(
endpoints, last_updated_at, blocked_uids, bypassed_ips, received_any_stats
endpoints,
last_updated_at,
blocked_uids,
bypassed_ips,
received_any_stats,
excluded_uids_from_rate_limiting,
)
self.block_new_outgoing_requests = False
self.outbound_domains = {}
Expand All @@ -32,10 +38,14 @@ def update(
blocked_uids,
bypassed_ips,
received_any_stats: bool,
excluded_uids_from_rate_limiting=None,
):
self.last_updated_at = last_updated_at
self.received_any_stats = bool(received_any_stats)
self.blocked_uids = set(blocked_uids)
self.excluded_uids_from_rate_limiting = set(
excluded_uids_from_rate_limiting or []
)
self.set_endpoints(endpoints)
self.set_bypassed_ips(bypassed_ips)

Expand Down
46 changes: 46 additions & 0 deletions aikido_zen/background_process/service_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,49 @@ def test_service_config_with_empty_allowlist():
assert admin_endpoint["route"] == "/admin"
assert isinstance(admin_endpoint["allowedIPAddresses"], list)
assert len(admin_endpoint["allowedIPAddresses"]) == 0


def test_excluded_uids_from_rate_limiting_defaults_to_empty():
config = ServiceConfig(
endpoints=[],
last_updated_at=0,
blocked_uids=set(),
bypassed_ips=[],
received_any_stats=False,
)
assert config.excluded_uids_from_rate_limiting == set()


def test_excluded_uids_from_rate_limiting_stored_as_set():
config = ServiceConfig(
endpoints=[],
last_updated_at=0,
blocked_uids=set(),
bypassed_ips=[],
received_any_stats=False,
excluded_uids_from_rate_limiting=["user1", "user2"],
)
assert config.excluded_uids_from_rate_limiting == {"user1", "user2"}


def test_excluded_uids_from_rate_limiting_updated_via_update():
config = ServiceConfig(
endpoints=[],
last_updated_at=0,
blocked_uids=set(),
bypassed_ips=[],
received_any_stats=False,
excluded_uids_from_rate_limiting=["user1"],
)
assert "user1" in config.excluded_uids_from_rate_limiting

config.update(
endpoints=[],
last_updated_at=0,
blocked_uids=set(),
bypassed_ips=[],
received_any_stats=False,
excluded_uids_from_rate_limiting=["user2", "user3"],
)
assert config.excluded_uids_from_rate_limiting == {"user2", "user3"}
assert "user1" not in config.excluded_uids_from_rate_limiting
6 changes: 6 additions & 0 deletions aikido_zen/ratelimiting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def should_ratelimit_request(
max_requests = int(endpoint["rateLimiting"]["maxRequests"])
windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"])

if (
user
and user.get("id") in connection_manager.conf.excluded_uids_from_rate_limiting
):
return {"block": False}

if group:
allowed = connection_manager.rate_limiter.is_allowed(
get_key_for_group(endpoint, group),
Expand Down
76 changes: 75 additions & 1 deletion aikido_zen/ratelimiting/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ def user():
return {"id": "user123"}


def create_connection_manager(endpoints=[], bypassed_ips=[]):
def create_connection_manager(endpoints=[], bypassed_ips=[], excluded_uids=[]):
cm = MagicMock()
cm.conf = ServiceConfig(
endpoints=endpoints,
last_updated_at=1,
blocked_uids=[],
bypassed_ips=bypassed_ips,
received_any_stats=True,
excluded_uids_from_rate_limiting=excluded_uids,
)
cm.rate_limiter = RateLimiter(
max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes
Expand Down Expand Up @@ -478,6 +479,79 @@ def test_works_with_multiple_rate_limit_groups_and_different_users():
}


def test_excluded_user_bypasses_user_rate_limit():
endpoint = {
"method": "POST",
"route": "/login",
"forceProtectionOff": False,
"rateLimiting": {
"enabled": True,
"maxRequests": 3,
"windowSizeInMS": 1000,
},
}
cm = create_connection_manager([endpoint], excluded_uids=["user123"])
route_metadata = create_route_metadata()

# Excluded user should never be blocked, even past maxRequests
for _ in range(5):
assert should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "user123"}, cm
) == {"block": False}


def test_non_excluded_user_still_rate_limited():
endpoint = {
"method": "POST",
"route": "/login",
"forceProtectionOff": False,
"rateLimiting": {
"enabled": True,
"maxRequests": 3,
"windowSizeInMS": 1000,
},
}
cm = create_connection_manager([endpoint], excluded_uids=["other_user"])
route_metadata = create_route_metadata()

assert should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "user123"}, cm
) == {"block": False}
assert should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "user123"}, cm
) == {"block": False}
assert should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "user123"}, cm
) == {"block": False}
assert should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "user123"}, cm
) == {
"block": True,
"trigger": "user",
}


def test_excluded_user_bypasses_group_rate_limit():
endpoint = {
"method": "POST",
"route": "/login",
"forceProtectionOff": False,
"rateLimiting": {
"enabled": True,
"maxRequests": 3,
"windowSizeInMS": 1000,
},
}
cm = create_connection_manager([endpoint], excluded_uids=["user123"])
route_metadata = create_route_metadata()

# Excluded user should never be blocked, even past maxRequests, even with a group set
for _ in range(5):
assert should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "user123"}, cm, "group1"
) == {"block": False}


def test_rate_limits_by_group_if_user_is_not_set():
cm = create_connection_manager(
[
Expand Down
1 change: 1 addition & 0 deletions aikido_zen/thread/thread_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def reset(self):
bypassed_ips=[],
last_updated_at=-1,
received_any_stats=False,
excluded_uids_from_rate_limiting=set(),
)
self.middleware_installed = False
self.hostnames.clear()
Expand Down
Loading