Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fastapi/fastapi/security/.audit.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"contributor": "Daxia",
"environment_config": "Public audit metadata only. Private startup instructions, hidden configuration, credentials, and session-internal context are intentionally not published.",
"completed_at": "2026-06-27T00:10:00+08:00"
}
1 change: 1 addition & 0 deletions fastapi/fastapi/security/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .api_key import APIKeyCookie as APIKeyCookie
from .api_key import APIKeyHeader as APIKeyHeader
from .api_key import APIKeyQuery as APIKeyQuery
from .api_key import APIKeyWithRateLimit as APIKeyWithRateLimit
from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials
from .http import HTTPBasic as HTTPBasic
from .http import HTTPBasicCredentials as HTTPBasicCredentials
Expand Down
143 changes: 142 additions & 1 deletion fastapi/fastapi/security/api_key.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
import math
import re
import threading
import time
from collections import deque
from collections.abc import Sequence
from typing import Annotated

from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED
from starlette.responses import Response
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_429_TOO_MANY_REQUESTS

_RATE_LIMIT_PATTERN = re.compile(r"^\s*(?P<count>\d+)\s*/\s*(?P<unit>\w+)\s*$")
_WINDOW_SECONDS = {
"second": 1,
"seconds": 1,
"minute": 60,
"minutes": 60,
"hour": 3600,
"hours": 3600,
}
_DEPRECATED_KEY_WARNING = '299 - "API key is deprecated and will be deactivated"'


class APIKeyBase(SecurityBase):
Expand Down Expand Up @@ -232,6 +250,129 @@ async def __call__(self, request: Request) -> str | None:
return self.check_api_key(api_key)


class APIKeyWithRateLimit(APIKeyHeader):
def __init__(
self,
*,
name: Annotated[str, Doc("Header name.")],
rate_limit: Annotated[
str,
Doc(
"""
Request limit using the format "count/unit", for example
"100/minute" or "1000/hour".
"""
),
],
deprecated_keys: Annotated[
Sequence[str] | None,
Doc(
"""
API keys that still authenticate but add a Warning response header.
"""
),
] = None,
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
Whether to raise automatically when the header is missing.
"""
),
] = True,
):
self.rate_limit_count, self.rate_limit_window = self._parse_rate_limit(
rate_limit
)
self.deprecated_keys = set(deprecated_keys or ())
self._requests: dict[str, deque[float]] = {}
self._lock = threading.RLock()
self._time = time.monotonic
super().__init__(
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)

async def __call__(self, request: Request, response: Response) -> str | None:
api_key = request.headers.get(self.model.name)
api_key = self.check_api_key(api_key)
if api_key is None:
return None

warning_header = self._warning_header(api_key)
retry_after = self._record_request(api_key)
if retry_after is not None:
headers = {"Retry-After": str(retry_after)}
if warning_header is not None:
headers["Warning"] = warning_header
raise HTTPException(
status_code=HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded",
headers=headers,
)

if warning_header is not None:
response.headers["Warning"] = warning_header

return api_key

@staticmethod
def _parse_rate_limit(rate_limit: str) -> tuple[int, int]:
match = _RATE_LIMIT_PATTERN.match(rate_limit)
if not match:
raise ValueError(
'rate_limit must use the format "count/unit", e.g. "100/minute"'
)

count = int(match.group("count"))
unit = match.group("unit").lower()
if count <= 0 or unit not in _WINDOW_SECONDS:
raise ValueError(
"rate_limit count must be positive and unit must be second, minute, or hour"
)
return count, _WINDOW_SECONDS[unit]

def _record_request(self, api_key: str) -> int | None:
now = self._time()
with self._lock:
request_times = self._requests.setdefault(api_key, deque())
self._remove_expired_requests(request_times, now)

if len(request_times) >= self.rate_limit_count:
retry_after = self.rate_limit_window - (now - request_times[0])
return max(1, math.ceil(retry_after))

request_times.append(now)
return None

def _remove_expired_requests(self, request_times: deque[float], now: float) -> None:
while request_times and now - request_times[0] >= self.rate_limit_window:
request_times.popleft()

def _warning_header(self, api_key: str) -> str | None:
if api_key in self.deprecated_keys:
return _DEPRECATED_KEY_WARNING
return None


class APIKeyCookie(APIKeyBase):
"""
API key authentication using a cookie.
Expand Down
126 changes: 126 additions & 0 deletions fastapi/tests/test_security_api_key_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Annotated

import pytest
from fastapi import Depends, FastAPI
from fastapi.security import APIKeyHeader, APIKeyWithRateLimit
from fastapi.testclient import TestClient


def build_client(api_key_scheme: APIKeyWithRateLimit) -> TestClient:
app = FastAPI()

@app.get("/items/")
def read_items(api_key: Annotated[str | None, Depends(api_key_scheme)]):
return {"api_key": api_key}

return TestClient(app)


def test_existing_api_key_header_export_is_unchanged():
assert APIKeyHeader(name="key").__class__.__name__ == "APIKeyHeader"


def test_api_key_rate_limit_enforces_per_key_limit_with_retry_after():
current_time = [100.0]
api_key_scheme = APIKeyWithRateLimit(name="x-key", rate_limit="2/minute")
api_key_scheme._time = lambda: current_time[0]
client = build_client(api_key_scheme)

first = client.get("/items/", headers={"x-key": "alpha"})
second = client.get("/items/", headers={"x-key": "alpha"})
third = client.get("/items/", headers={"x-key": "alpha"})
independent_key = client.get("/items/", headers={"x-key": "beta"})

assert first.status_code == 200
assert second.status_code == 200
assert third.status_code == 429
assert third.headers["retry-after"] == "60"
assert third.json() == {"detail": "Rate limit exceeded"}
assert independent_key.status_code == 200
assert independent_key.json() == {"api_key": "beta"}


def test_api_key_rate_limit_resets_after_window_expires():
current_time = [10.0]
api_key_scheme = APIKeyWithRateLimit(name="x-key", rate_limit="1/minute")
api_key_scheme._time = lambda: current_time[0]
client = build_client(api_key_scheme)

assert client.get("/items/", headers={"x-key": "alpha"}).status_code == 200
assert client.get("/items/", headers={"x-key": "alpha"}).status_code == 429

current_time[0] = 70.0

assert client.get("/items/", headers={"x-key": "alpha"}).status_code == 200


def test_deprecated_api_key_adds_warning_header_but_active_key_does_not():
api_key_scheme = APIKeyWithRateLimit(
name="x-key",
rate_limit="10/minute",
deprecated_keys=["old-key"],
)
client = build_client(api_key_scheme)

deprecated_response = client.get("/items/", headers={"x-key": "old-key"})
active_response = client.get("/items/", headers={"x-key": "new-key"})

assert deprecated_response.status_code == 200
assert "deprecated" in deprecated_response.headers["warning"]
assert active_response.status_code == 200
assert "warning" not in active_response.headers


def test_deprecated_api_key_warning_is_preserved_on_rate_limit_response():
current_time = [30.0]
api_key_scheme = APIKeyWithRateLimit(
name="x-key",
rate_limit="1/minute",
deprecated_keys=["old-key"],
)
api_key_scheme._time = lambda: current_time[0]
client = build_client(api_key_scheme)

assert client.get("/items/", headers={"x-key": "old-key"}).status_code == 200
response = client.get("/items/", headers={"x-key": "old-key"})

assert response.status_code == 429
assert response.headers["retry-after"] == "60"
assert "deprecated" in response.headers["warning"]


def test_optional_missing_api_key_skips_rate_limit():
api_key_scheme = APIKeyWithRateLimit(
name="x-key", rate_limit="1/minute", auto_error=False
)
client = build_client(api_key_scheme)

first = client.get("/items/")
second = client.get("/items/")

assert first.status_code == 200
assert first.json() == {"api_key": None}
assert second.status_code == 200
assert second.json() == {"api_key": None}


@pytest.mark.parametrize("rate_limit", ["0/minute", "10/week", "abc"])
def test_api_key_rate_limit_rejects_invalid_limits(rate_limit: str):
with pytest.raises(ValueError):
APIKeyWithRateLimit(name="x-key", rate_limit=rate_limit)


def test_api_key_rate_limit_store_handles_concurrent_requests():
api_key_scheme = APIKeyWithRateLimit(name="x-key", rate_limit="50/minute")

with ThreadPoolExecutor(max_workers=8) as executor:
results = list(
executor.map(
lambda _: api_key_scheme._record_request("shared-key"),
range(20),
)
)

assert results == [None] * 20
assert len(api_key_scheme._requests["shared-key"]) == 20
Loading