Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions integration/tests/posit/test_sessions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Integration tests for ``posit.connect.sessions.Session``.

These tests exercise the real ``requests`` / ``urllib3`` / socket stack
against a local HTTP server, rather than monkey-patching ``HTTPAdapter.send``
the way unit tests with ``responses`` do. That lets us verify that the
redirect semantics we rely on (POST body preservation on 301/302,
``Authorization`` stripping on cross-origin hops, ``response.history``
population, ``TooManyRedirects`` on overflow) actually hold end-to-end.

The server is an in-process ``http.server`` instance on ``127.0.0.1`` — no
Connect instance is required, so these run in any environment.
"""

from __future__ import annotations

import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Callable, List, Tuple

import pytest
import requests

from posit.connect.sessions import Session


class _Recorder:
"""Shared state between the test and the request handler thread."""

def __init__(self) -> None:
self.requests: List[Tuple[str, str, dict, bytes]] = [] # method, path, headers, body
self.responder: Callable[[BaseHTTPRequestHandler], None] | None = None


def _make_handler(recorder: _Recorder):
class Handler(BaseHTTPRequestHandler):
def log_message(self, format, *args): # noqa: A002 - silence default stderr logging
pass

def _record(self) -> None:
length = int(self.headers.get("Content-Length") or 0)
body = self.rfile.read(length) if length else b""
recorder.requests.append(
(self.command, self.path, dict(self.headers), body),
)

def do_GET(self) -> None: # noqa: N802
self._record()
assert recorder.responder is not None
recorder.responder(self)

def do_POST(self) -> None: # noqa: N802
self._record()
assert recorder.responder is not None
recorder.responder(self)

return Handler


@pytest.fixture
def server():
recorder = _Recorder()
httpd = HTTPServer(("127.0.0.1", 0), _make_handler(recorder))
thread = threading.Thread(target=httpd.serve_forever, daemon=True)
thread.start()
host, port = httpd.server_address
base = f"http://{host}:{port}"
try:
yield base, recorder
finally:
httpd.shutdown()
httpd.server_close()
thread.join(timeout=5)


def _ok(handler: BaseHTTPRequestHandler) -> None:
handler.send_response(200)
handler.send_header("Content-Type", "application/json")
handler.send_header("Content-Length", "2")
handler.end_headers()
handler.wfile.write(b"{}")


def _redirect(status: int, location: str) -> Callable[[BaseHTTPRequestHandler], None]:
def respond(handler: BaseHTTPRequestHandler) -> None:
handler.send_response(status)
handler.send_header("Location", location)
handler.send_header("Content-Length", "0")
handler.end_headers()

return respond


def test_post_preserves_body_across_302(server):
"""A POST that 302s to a new path re-POSTs the body (libcurl POSTREDIR)."""
base, recorder = server
hops = iter([_redirect(302, "/next"), _ok])
recorder.responder = lambda h: next(hops)(h)

session = Session()
response = session.post(f"{base}/start", data=b"payload-bytes", preserve_post=True)

assert response.status_code == 200
assert len(recorder.requests) == 2
assert recorder.requests[0][0] == "POST"
assert recorder.requests[0][3] == b"payload-bytes"
assert recorder.requests[1][0] == "POST"
assert recorder.requests[1][3] == b"payload-bytes"
# response.history should carry the 302 hop.
assert [r.status_code for r in response.history] == [302]


def test_post_downgrades_to_get_when_preserve_post_false(server):
base, recorder = server
hops = iter([_redirect(302, "/next"), _ok])
recorder.responder = lambda h: next(hops)(h)

session = Session()
response = session.post(f"{base}/start", data=b"payload", preserve_post=False)

assert response.status_code == 200
assert recorder.requests[0][0] == "POST"
assert recorder.requests[1][0] == "GET"
assert recorder.requests[1][3] == b"" # body dropped on downgrade


@pytest.mark.parametrize("status", [307, 308])
def test_post_preserves_method_on_307_308(server, status):
base, recorder = server
hops = iter([_redirect(status, "/next"), _ok])
recorder.responder = lambda h: next(hops)(h)

session = Session()
response = session.post(f"{base}/start", data=b"payload", preserve_post=False)

assert response.status_code == 200
assert recorder.requests[1][0] == "POST"
assert recorder.requests[1][3] == b"payload"


def test_cross_origin_redirect_strips_authorization(server):
"""A cross-origin redirect must not forward the Authorization header.

We use ``127.0.0.1`` as the request origin and redirect to ``localhost``
on the same port — different hostnames, so ``rebuild_auth`` must strip
the session-level Authorization header on the follow-up hop.
"""
base, recorder = server
_, port = base.rsplit(":", 1)
cross_origin_next = f"http://localhost:{port}/collect"

hops = iter([_redirect(302, cross_origin_next), _ok])
recorder.responder = lambda h: next(hops)(h)

session = Session()
session.headers["Authorization"] = "Key super-secret-token"
session.post(f"{base}/start", data=b"payload", preserve_post=True)

assert len(recorder.requests) == 2
assert recorder.requests[0][2].get("Authorization") == "Key super-secret-token"
# The cross-origin second hop must NOT carry the Authorization header.
assert "Authorization" not in recorder.requests[1][2]
assert "authorization" not in {k.lower() for k in recorder.requests[1][2]}


def test_same_origin_redirect_preserves_authorization(server):
base, recorder = server
hops = iter([_redirect(302, "/next"), _ok])
recorder.responder = lambda h: next(hops)(h)

session = Session()
session.headers["Authorization"] = "Key super-secret-token"
session.post(f"{base}/start", data=b"payload", preserve_post=True)

assert len(recorder.requests) == 2
assert recorder.requests[0][2].get("Authorization") == "Key super-secret-token"
assert recorder.requests[1][2].get("Authorization") == "Key super-secret-token"


def test_max_redirects_raises_too_many_redirects(server):
base, recorder = server
recorder.responder = _redirect(302, "/loop")

session = Session()
with pytest.raises(requests.exceptions.TooManyRedirects):
session.post(f"{base}/start", data=b"payload", max_redirects=3, preserve_post=True)

# Initial POST + 3 followed redirects before requests raises.
assert len(recorder.requests) == 4
125 changes: 44 additions & 81 deletions src/posit/connect/sessions.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,41 @@
from urllib.parse import urljoin

import requests


class Session(requests.Session):
"""Custom session that implements CURLOPT_POSTREDIR.

This class mimics the functionality of CURLOPT_POSTREDIR from libcurl by
providing a custom implementation of the POST method. It allows the caller
to control whether the original POST data is preserved on redirects or if the
request should be converted to a GET when a redirect occurs. This is achieved
by disabling automatic redirect handling and manually following the redirect
chain with the desired behavior.

Notes
-----
The custom `post` method in this class:

- Disables automatic redirect handling by setting ``allow_redirects=False``.
- Manually follows redirects up to a specified ``max_redirects``.
- Determines the HTTP method for subsequent requests based on the response
status code and the ``preserve_post`` flag:

- For HTTP status codes 307 and 308, the method and request body are
always preserved as POST.
- For other redirects (e.g., 301, 302, 303), the behavior is determined
by ``preserve_post``:
- If ``preserve_post=True``, the POST method is maintained.
- If ``preserve_post=False``, the method is converted to GET and the
request body is discarded.
"""Custom session that preserves POST bodies across 301/302 redirects.

RFC 7231 allows clients to downgrade POST to GET on 301/302, and
``requests`` does so by default. Some Connect endpoints issue a 302 and
expect the client to re-POST the original body to the new location
(mirroring libcurl's ``CURLOPT_POSTREDIR`` behavior). Overriding
:meth:`rebuild_method` keeps every other piece of the ``requests``
redirect machinery intact — ``rebuild_auth`` (which strips the
``Authorization`` header on cross-origin hops), cookie propagation,
``response.history``, ``TooManyRedirects``, proxy rebuild, and streaming
semantics — so this class is a minimal, safe override rather than a
hand-rolled redirect loop.
"""

Examples
--------
Create a session and send a POST request while preserving POST data on redirects:
def __init__(self):
super().__init__()
self._preserve_post_on_redirect = True

>>> session = Session()
>>> response = session.post(
... "https://example.com/api", data={"key": "value"}, preserve_post=True
... )
>>> print(response.status_code)
def rebuild_method(self, prepared_request, response):
"""Preserve POST across 301/302 when ``preserve_post_on_redirect`` is set.

See Also
--------
requests.Session : The base session class from the requests library.
"""
303 always downgrades to GET per the HTTP spec. 307/308 already
preserve the method in the base implementation.
"""
if (
self._preserve_post_on_redirect
and prepared_request.method == "POST"
and response.status_code in (301, 302)
):
return
super().rebuild_method(prepared_request, response)

def post(self, url, data=None, json=None, preserve_post=True, max_redirects=5, **kwargs):
"""
Send a POST request and handle redirects manually.
"""Send a POST request.

Parameters
----------
Expand All @@ -58,46 +46,21 @@ def post(self, url, data=None, json=None, preserve_post=True, max_redirects=5, *
json : any, optional
The JSON data to send.
preserve_post : bool, optional
If True, re-send POST data on redirects (mimicking CURLOPT_POSTREDIR);
if False, converts to GET on 301/302/303 responses.
If True (default), re-send POST data on 301/302 redirects
(mimicking ``CURLOPT_POSTREDIR``). If False, fall back to the
default ``requests`` behavior (downgrade to GET on 301/302).
max_redirects : int, optional
Maximum number of redirects to follow.
Maximum number of redirects to follow before raising
:class:`requests.exceptions.TooManyRedirects`.
**kwargs
Additional keyword arguments passed to the request.

Returns
-------
requests.Response
The final response after following redirects.
Additional keyword arguments passed to :meth:`requests.Session.request`.
"""
# Force manual redirect handling by disabling auto redirects.
kwargs["allow_redirects"] = False

# Initial POST request
response = super().post(url, data=data, json=json, **kwargs)
redirect_count = 0

# Manually follow redirects, if any
while response.is_redirect and redirect_count < max_redirects:
redirect_url = response.headers.get("location")
if not redirect_url:
break # No redirect URL; exit loop

redirect_url = urljoin(response.url, redirect_url)

# For 307 and 308 the HTTP spec mandates preserving the method and body.
if response.status_code in (307, 308):
method = "POST"
else:
if preserve_post:
method = "POST"
else:
method = "GET"
data = None
json = None

# Perform the next request in the redirect chain.
response = self.request(method, redirect_url, data=data, json=json, **kwargs)
redirect_count += 1

return response
previous_preserve = self._preserve_post_on_redirect
previous_max = self.max_redirects
self._preserve_post_on_redirect = preserve_post
self.max_redirects = max_redirects
try:
return super().post(url, data=data, json=json, **kwargs)
finally:
self._preserve_post_on_redirect = previous_preserve
self.max_redirects = previous_max
Loading
Loading