Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 89 additions & 9 deletions backend/cmd/api/internal/auth/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package auth

import (
"context"
"encoding/json"
"errors"
"log"
"net"
"net/http"
Expand All @@ -11,6 +14,41 @@ import (
"github.com/CMS-Enterprise/ztmf/backend/internal/model"
)

// Error codes returned in the JSON body alongside the HTTP status. The FE keys
// off these to render distinguishable copy: UNAUTHORIZED maps to "your session
// has expired", ACCOUNT_NOT_PROVISIONED maps to a terminal "contact your
// administrator" message with no retry CTA. See ztmf-ui#403.
const (
CodeUnauthorized = "UNAUTHORIZED"
CodeForbiddenOrigin = "FORBIDDEN_ORIGIN"
CodeAccountNotProvisioned = "ACCOUNT_NOT_PROVISIONED"
)

// Package-level seams over the model lookups so tests can stub them without a
// database. Production wiring is the real model functions.
var (
findUserByID = model.FindUserByID
findUserByEmail = model.FindUserByEmail
)

// errorBody is the JSON shape returned on every middleware-rejected request.
// Single shape across 401/403/500 so the FE interceptor can rely on it and
// branch on `code` rather than parsing status alone.
type errorBody struct {
Error string `json:"error"`
Code string `json:"code,omitempty"`
}

// writeJSONError writes a standardized JSON error response and is the only
// rejection surface used by Middleware. Centralizing the shape keeps the FE
// interceptor's contract single-sourced.
func writeJSONError(w http.ResponseWriter, status int, msg, code string) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(errorBody{Error: msg, Code: code})
}

// Middleware authenticates /api/* requests and attaches the matching user to
// the request context. It accepts two token sources, in order:
//
Expand All @@ -21,13 +59,20 @@ import (
// (HS256 bearer) and the E2E suite working, and also covers the interim
// period before the ALB rule flips, where the ALB still injects the IdP
// token on /api/*.
//
// Rejection statuses distinguish three failure shapes the FE needs to
// disambiguate (ztmf-ui#403): 401 for missing/invalid session, 403 with code
// ACCOUNT_NOT_PROVISIONED for an authenticated identity with no app account
// (or a soft-deleted one), and 403 with code FORBIDDEN_ORIGIN for CSRF.
func Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cfg := config.GetInstance()

claims, isSession, ok := claimsFromRequest(r)
if !ok {
http.Error(w, "unauthorized", http.StatusUnauthorized)
writeJSONError(w, http.StatusUnauthorized,
"Your session has expired. Please sign in again.",
CodeUnauthorized)
return
}

Expand All @@ -37,7 +82,9 @@ func Middleware(next http.Handler) http.Handler {
// cookie path is browser-driven; the bearer path is for API clients and
// is not subject to CSRF.
if isSession && !isSafeMethod(r.Method) && !sameOrigin(r) {
http.Error(w, "forbidden", http.StatusForbidden)
writeJSONError(w, http.StatusForbidden,
"Request blocked: origin not allowed.",
CodeForbiddenOrigin)
return
}

Expand All @@ -51,12 +98,16 @@ func Middleware(next http.Handler) http.Handler {
err error
)
if isSession {
user, err = model.FindUserByID(r.Context(), claims.Subject)
user, err = findUserByID(r.Context(), claims.Subject)
} else {
user, err = model.FindUserByEmail(r.Context(), IdentifierFromClaims(claims))
user, err = findUserByEmail(r.Context(), IdentifierFromClaims(claims))
}

if err != nil && !isSession && cfg.IsLocal() {
// Local dev convenience: an unauthenticated identity that doesn't
// map to a row gets a fresh OWNER user so contributors can poke
// around without seeding by hand. Any lookup error (not-found,
// connection blip) routes through this path locally.
log.Printf("Local dev: auto-creating OWNER user for %s\n", claims.Email)
user = &model.User{
Email: claims.Email,
Expand All @@ -69,18 +120,39 @@ func Middleware(next http.Handler) http.Handler {
user, err = user.Save(r.Context())
if err != nil {
log.Printf("Failed to auto-create user: %s\n", err)
http.Error(w, "unauthorized", http.StatusUnauthorized)
writeJSONError(w, http.StatusInternalServerError,
"internal error", "")
return
}
} else if errors.Is(err, model.ErrNoData) {
// The IdP authenticated this identity, but it has no row in the
// ZTMF users table. Distinct from "session expired" - the session
// is valid; the user simply has no app account. The FE branches on
// this code to render a terminal "contact your administrator"
// message instead of looping the user back through the IdP.
log.Printf("authenticated identity has no ZTMF account: %s\n", IdentifierFromClaims(claims))
writeJSONError(w, http.StatusForbidden,
"Your ZTMF account is not set up. Contact your administrator to request access.",
CodeAccountNotProvisioned)
return
} else if err != nil {
log.Printf("Could not find user for request: %s\n", err)
http.Error(w, "unauthorized", http.StatusUnauthorized)
// DB connection blip, decode failure, etc. Not a credential
// problem, so do not present as one to the FE.
log.Printf("user lookup failed: %s\n", err)
writeJSONError(w, http.StatusInternalServerError,
"internal error", "")
return
}

if user.Deleted {
log.Println("a deleted user tried to access the API")
http.Error(w, "unauthorized", http.StatusUnauthorized)
// Same FE-facing UX as the never-provisioned case: the IdP
// session is valid but no usable app account exists. Logged
// distinctly so support can tell "offboarded" from "never
// onboarded" without grepping the users table.
log.Printf("deleted user attempted to access the API: %s\n", user.Email)
writeJSONError(w, http.StatusForbidden,
"Your ZTMF account is no longer active. Contact your administrator.",
CodeAccountNotProvisioned)
return
}

Expand All @@ -89,6 +161,14 @@ func Middleware(next http.Handler) http.Handler {
})
}

// Compile-time assertion that the model lookup vars have the signatures the
// middleware (and the tests) expect. Keeps a future signature drift in the
// model package from sneaking through.
var (
_ func(context.Context, string) (*model.User, error) = findUserByID
_ func(context.Context, string) (*model.User, error) = findUserByEmail
)

// isSafeMethod reports whether the HTTP method is read-only and therefore not
// a CSRF concern.
func isSafeMethod(method string) bool {
Expand Down
160 changes: 160 additions & 0 deletions backend/cmd/api/internal/auth/middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package auth

import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -76,6 +79,163 @@ func TestIsSafeMethod(t *testing.T) {
}
}

// TestMiddleware covers the three response shapes the FE keys off after
// ztmf-ui#403: unauthenticated -> 401 UNAUTHORIZED, authenticated identity
// with no app account -> 403 ACCOUNT_NOT_PROVISIONED, and the happy path
// where a provisioned user passes through to the next handler. A bonus case
// covers the soft-deleted user, which collapses into the same FE UX as
// "never provisioned" but logs distinctly.
func TestMiddleware(t *testing.T) {
cfg := config.GetInstance()

mintBearer := func(t *testing.T, email string) string {
t.Helper()
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, &Claims{
Email: email,
RegisteredClaims: jwt.RegisteredClaims{ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour))},
})
s, err := tok.SignedString([]byte(testHS256Secret))
require.NoError(t, err)
return s
}

nextFn := func(called *bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
*called = true
w.WriteHeader(http.StatusOK)
})
}

decodeBody := func(t *testing.T, w *httptest.ResponseRecorder) errorBody {
t.Helper()
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
var body errorBody
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
return body
}

// stubFindUserByEmail swaps the package-level seam for the duration of a
// subtest. The original is restored via t.Cleanup so subtests stay
// independent regardless of order.
stubFindUserByEmail := func(t *testing.T, fn func(context.Context, string) (*model.User, error)) {
t.Helper()
prev := findUserByEmail
findUserByEmail = fn
t.Cleanup(func() { findUserByEmail = prev })
}

t.Run("no auth -> 401 UNAUTHORIZED", func(t *testing.T) {
var called bool
r := httptest.NewRequest(http.MethodGet, "/api/v1/users/current", nil)
w := httptest.NewRecorder()

Middleware(nextFn(&called)).ServeHTTP(w, r)

assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.False(t, called, "next must not be called on rejection")
body := decodeBody(t, w)
assert.Equal(t, CodeUnauthorized, body.Code)
assert.NotEmpty(t, body.Error)
})

t.Run("bearer present but token invalid -> 401 UNAUTHORIZED", func(t *testing.T) {
var called bool
r := httptest.NewRequest(http.MethodGet, "/api/v1/users/current", nil)
r.Header.Set(cfg.Auth.HeaderField, "Bearer not.a.real.token")
w := httptest.NewRecorder()

Middleware(nextFn(&called)).ServeHTTP(w, r)

assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.False(t, called)
body := decodeBody(t, w)
assert.Equal(t, CodeUnauthorized, body.Code)
})

t.Run("authed + unprovisioned -> 403 ACCOUNT_NOT_PROVISIONED", func(t *testing.T) {
stubFindUserByEmail(t, func(context.Context, string) (*model.User, error) {
return nil, model.ErrNoData
})

var called bool
r := httptest.NewRequest(http.MethodGet, "/api/v1/users/current", nil)
r.Header.Set(cfg.Auth.HeaderField, "Bearer "+mintBearer(t, "ghost@nowhere.xyz"))
w := httptest.NewRecorder()

Middleware(nextFn(&called)).ServeHTTP(w, r)

assert.Equal(t, http.StatusForbidden, w.Code)
assert.False(t, called, "next must not be called for an unprovisioned identity")
body := decodeBody(t, w)
assert.Equal(t, CodeAccountNotProvisioned, body.Code)
assert.NotEmpty(t, body.Error)
})

t.Run("authed + provisioned -> next called", func(t *testing.T) {
stubFindUserByEmail(t, func(_ context.Context, email string) (*model.User, error) {
return &model.User{
UserID: "11111111-1111-1111-1111-111111111111",
Email: email,
Role: "OWNER",
}, nil
})

var called bool
r := httptest.NewRequest(http.MethodGet, "/api/v1/users/current", nil)
r.Header.Set(cfg.Auth.HeaderField, "Bearer "+mintBearer(t, "provisioned@empire.test"))
w := httptest.NewRecorder()

Middleware(nextFn(&called)).ServeHTTP(w, r)

assert.Equal(t, http.StatusOK, w.Code)
assert.True(t, called, "next must be called on the happy path")
})

t.Run("authed + soft-deleted -> 403 ACCOUNT_NOT_PROVISIONED", func(t *testing.T) {
stubFindUserByEmail(t, func(_ context.Context, email string) (*model.User, error) {
return &model.User{
UserID: "22222222-2222-2222-2222-222222222222",
Email: email,
Role: "OWNER",
Deleted: true,
}, nil
})

var called bool
r := httptest.NewRequest(http.MethodGet, "/api/v1/users/current", nil)
r.Header.Set(cfg.Auth.HeaderField, "Bearer "+mintBearer(t, "offboarded@empire.test"))
w := httptest.NewRecorder()

Middleware(nextFn(&called)).ServeHTTP(w, r)

assert.Equal(t, http.StatusForbidden, w.Code)
assert.False(t, called)
body := decodeBody(t, w)
assert.Equal(t, CodeAccountNotProvisioned, body.Code)
})

t.Run("authed + lookup errors (non-ErrNoData) -> 500", func(t *testing.T) {
stubFindUserByEmail(t, func(context.Context, string) (*model.User, error) {
return nil, errors.New("simulated db connection blip")
})

var called bool
r := httptest.NewRequest(http.MethodGet, "/api/v1/users/current", nil)
r.Header.Set(cfg.Auth.HeaderField, "Bearer "+mintBearer(t, "anyone@empire.test"))
w := httptest.NewRecorder()

Middleware(nextFn(&called)).ServeHTTP(w, r)

assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.False(t, called, "next must not be called on an upstream failure")
// 500 carries an error but no code: opaque to the FE on purpose so the
// "contact your administrator" terminal copy is not triggered by a
// transient DB blip.
body := decodeBody(t, w)
assert.Empty(t, body.Code)
})
}

func TestSameOrigin(t *testing.T) {
// CookieDomain is unset in the test env, so sameOrigin falls back to the
// request Host.
Expand Down