From 66560e13c93b6635c5e99239a57b53f12f6d5929 Mon Sep 17 00:00:00 2001 From: Daxia Date: Fri, 26 Jun 2026 23:41:29 +0800 Subject: [PATCH] Add FastAPI API key rate limiting --- fastapi/fastapi/security/.audit.json | 5 + fastapi/fastapi/security/__init__.py | 1 + fastapi/fastapi/security/api_key.py | 143 +++++++++++++++++- .../tests/test_security_api_key_rate_limit.py | 126 +++++++++++++++ 4 files changed, 274 insertions(+), 1 deletion(-) create mode 100644 fastapi/fastapi/security/.audit.json create mode 100644 fastapi/tests/test_security_api_key_rate_limit.py diff --git a/fastapi/fastapi/security/.audit.json b/fastapi/fastapi/security/.audit.json new file mode 100644 index 000000000..c49d452e7 --- /dev/null +++ b/fastapi/fastapi/security/.audit.json @@ -0,0 +1,5 @@ +{ + "contributor": "Daxia", + "environment_config": "Public audit metadata only. Private startup instructions, hidden configuration, credentials, and session-internal context are intentionally not published.", + "completed_at": "2026-06-27T00:10:00+08:00" +} diff --git a/fastapi/fastapi/security/__init__.py b/fastapi/fastapi/security/__init__.py index 3aa6bf21e..f5c16d393 100644 --- a/fastapi/fastapi/security/__init__.py +++ b/fastapi/fastapi/security/__init__.py @@ -1,6 +1,7 @@ from .api_key import APIKeyCookie as APIKeyCookie from .api_key import APIKeyHeader as APIKeyHeader from .api_key import APIKeyQuery as APIKeyQuery +from .api_key import APIKeyWithRateLimit as APIKeyWithRateLimit from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials from .http import HTTPBasic as HTTPBasic from .http import HTTPBasicCredentials as HTTPBasicCredentials diff --git a/fastapi/fastapi/security/api_key.py b/fastapi/fastapi/security/api_key.py index 83a4585a0..0e062f8cb 100644 --- a/fastapi/fastapi/security/api_key.py +++ b/fastapi/fastapi/security/api_key.py @@ -1,3 +1,9 @@ +import math +import re +import threading +import time +from collections import deque +from collections.abc import Sequence from typing import Annotated from annotated_doc import Doc @@ -5,7 +11,19 @@ from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.status import HTTP_401_UNAUTHORIZED +from starlette.responses import Response +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_429_TOO_MANY_REQUESTS + +_RATE_LIMIT_PATTERN = re.compile(r"^\s*(?P\d+)\s*/\s*(?P\w+)\s*$") +_WINDOW_SECONDS = { + "second": 1, + "seconds": 1, + "minute": 60, + "minutes": 60, + "hour": 3600, + "hours": 3600, +} +_DEPRECATED_KEY_WARNING = '299 - "API key is deprecated and will be deactivated"' class APIKeyBase(SecurityBase): @@ -232,6 +250,129 @@ async def __call__(self, request: Request) -> str | None: return self.check_api_key(api_key) +class APIKeyWithRateLimit(APIKeyHeader): + def __init__( + self, + *, + name: Annotated[str, Doc("Header name.")], + rate_limit: Annotated[ + str, + Doc( + """ + Request limit using the format "count/unit", for example + "100/minute" or "1000/hour". + """ + ), + ], + deprecated_keys: Annotated[ + Sequence[str] | None, + Doc( + """ + API keys that still authenticate but add a Warning response header. + """ + ), + ] = None, + scheme_name: Annotated[ + str | None, + Doc( + """ + Security scheme name. + """ + ), + ] = None, + description: Annotated[ + str | None, + Doc( + """ + Security scheme description. + """ + ), + ] = None, + auto_error: Annotated[ + bool, + Doc( + """ + Whether to raise automatically when the header is missing. + """ + ), + ] = True, + ): + self.rate_limit_count, self.rate_limit_window = self._parse_rate_limit( + rate_limit + ) + self.deprecated_keys = set(deprecated_keys or ()) + self._requests: dict[str, deque[float]] = {} + self._lock = threading.RLock() + self._time = time.monotonic + super().__init__( + name=name, + scheme_name=scheme_name, + description=description, + auto_error=auto_error, + ) + + async def __call__(self, request: Request, response: Response) -> str | None: + api_key = request.headers.get(self.model.name) + api_key = self.check_api_key(api_key) + if api_key is None: + return None + + warning_header = self._warning_header(api_key) + retry_after = self._record_request(api_key) + if retry_after is not None: + headers = {"Retry-After": str(retry_after)} + if warning_header is not None: + headers["Warning"] = warning_header + raise HTTPException( + status_code=HTTP_429_TOO_MANY_REQUESTS, + detail="Rate limit exceeded", + headers=headers, + ) + + if warning_header is not None: + response.headers["Warning"] = warning_header + + return api_key + + @staticmethod + def _parse_rate_limit(rate_limit: str) -> tuple[int, int]: + match = _RATE_LIMIT_PATTERN.match(rate_limit) + if not match: + raise ValueError( + 'rate_limit must use the format "count/unit", e.g. "100/minute"' + ) + + count = int(match.group("count")) + unit = match.group("unit").lower() + if count <= 0 or unit not in _WINDOW_SECONDS: + raise ValueError( + "rate_limit count must be positive and unit must be second, minute, or hour" + ) + return count, _WINDOW_SECONDS[unit] + + def _record_request(self, api_key: str) -> int | None: + now = self._time() + with self._lock: + request_times = self._requests.setdefault(api_key, deque()) + self._remove_expired_requests(request_times, now) + + if len(request_times) >= self.rate_limit_count: + retry_after = self.rate_limit_window - (now - request_times[0]) + return max(1, math.ceil(retry_after)) + + request_times.append(now) + return None + + def _remove_expired_requests(self, request_times: deque[float], now: float) -> None: + while request_times and now - request_times[0] >= self.rate_limit_window: + request_times.popleft() + + def _warning_header(self, api_key: str) -> str | None: + if api_key in self.deprecated_keys: + return _DEPRECATED_KEY_WARNING + return None + + class APIKeyCookie(APIKeyBase): """ API key authentication using a cookie. diff --git a/fastapi/tests/test_security_api_key_rate_limit.py b/fastapi/tests/test_security_api_key_rate_limit.py new file mode 100644 index 000000000..05fa5ac5a --- /dev/null +++ b/fastapi/tests/test_security_api_key_rate_limit.py @@ -0,0 +1,126 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import Annotated + +import pytest +from fastapi import Depends, FastAPI +from fastapi.security import APIKeyHeader, APIKeyWithRateLimit +from fastapi.testclient import TestClient + + +def build_client(api_key_scheme: APIKeyWithRateLimit) -> TestClient: + app = FastAPI() + + @app.get("/items/") + def read_items(api_key: Annotated[str | None, Depends(api_key_scheme)]): + return {"api_key": api_key} + + return TestClient(app) + + +def test_existing_api_key_header_export_is_unchanged(): + assert APIKeyHeader(name="key").__class__.__name__ == "APIKeyHeader" + + +def test_api_key_rate_limit_enforces_per_key_limit_with_retry_after(): + current_time = [100.0] + api_key_scheme = APIKeyWithRateLimit(name="x-key", rate_limit="2/minute") + api_key_scheme._time = lambda: current_time[0] + client = build_client(api_key_scheme) + + first = client.get("/items/", headers={"x-key": "alpha"}) + second = client.get("/items/", headers={"x-key": "alpha"}) + third = client.get("/items/", headers={"x-key": "alpha"}) + independent_key = client.get("/items/", headers={"x-key": "beta"}) + + assert first.status_code == 200 + assert second.status_code == 200 + assert third.status_code == 429 + assert third.headers["retry-after"] == "60" + assert third.json() == {"detail": "Rate limit exceeded"} + assert independent_key.status_code == 200 + assert independent_key.json() == {"api_key": "beta"} + + +def test_api_key_rate_limit_resets_after_window_expires(): + current_time = [10.0] + api_key_scheme = APIKeyWithRateLimit(name="x-key", rate_limit="1/minute") + api_key_scheme._time = lambda: current_time[0] + client = build_client(api_key_scheme) + + assert client.get("/items/", headers={"x-key": "alpha"}).status_code == 200 + assert client.get("/items/", headers={"x-key": "alpha"}).status_code == 429 + + current_time[0] = 70.0 + + assert client.get("/items/", headers={"x-key": "alpha"}).status_code == 200 + + +def test_deprecated_api_key_adds_warning_header_but_active_key_does_not(): + api_key_scheme = APIKeyWithRateLimit( + name="x-key", + rate_limit="10/minute", + deprecated_keys=["old-key"], + ) + client = build_client(api_key_scheme) + + deprecated_response = client.get("/items/", headers={"x-key": "old-key"}) + active_response = client.get("/items/", headers={"x-key": "new-key"}) + + assert deprecated_response.status_code == 200 + assert "deprecated" in deprecated_response.headers["warning"] + assert active_response.status_code == 200 + assert "warning" not in active_response.headers + + +def test_deprecated_api_key_warning_is_preserved_on_rate_limit_response(): + current_time = [30.0] + api_key_scheme = APIKeyWithRateLimit( + name="x-key", + rate_limit="1/minute", + deprecated_keys=["old-key"], + ) + api_key_scheme._time = lambda: current_time[0] + client = build_client(api_key_scheme) + + assert client.get("/items/", headers={"x-key": "old-key"}).status_code == 200 + response = client.get("/items/", headers={"x-key": "old-key"}) + + assert response.status_code == 429 + assert response.headers["retry-after"] == "60" + assert "deprecated" in response.headers["warning"] + + +def test_optional_missing_api_key_skips_rate_limit(): + api_key_scheme = APIKeyWithRateLimit( + name="x-key", rate_limit="1/minute", auto_error=False + ) + client = build_client(api_key_scheme) + + first = client.get("/items/") + second = client.get("/items/") + + assert first.status_code == 200 + assert first.json() == {"api_key": None} + assert second.status_code == 200 + assert second.json() == {"api_key": None} + + +@pytest.mark.parametrize("rate_limit", ["0/minute", "10/week", "abc"]) +def test_api_key_rate_limit_rejects_invalid_limits(rate_limit: str): + with pytest.raises(ValueError): + APIKeyWithRateLimit(name="x-key", rate_limit=rate_limit) + + +def test_api_key_rate_limit_store_handles_concurrent_requests(): + api_key_scheme = APIKeyWithRateLimit(name="x-key", rate_limit="50/minute") + + with ThreadPoolExecutor(max_workers=8) as executor: + results = list( + executor.map( + lambda _: api_key_scheme._record_request("shared-key"), + range(20), + ) + ) + + assert results == [None] * 20 + assert len(api_key_scheme._requests["shared-key"]) == 20