diff --git a/backend/cmd/api/internal/auth/token.go b/backend/cmd/api/internal/auth/token.go index fc7c209..7de65c3 100644 --- a/backend/cmd/api/internal/auth/token.go +++ b/backend/cmd/api/internal/auth/token.go @@ -12,6 +12,7 @@ import ( "log" "math/big" "net/http" + "regexp" "slices" "sync" "time" @@ -46,6 +47,28 @@ var ( // new kid is an acceptable trade for closing that amplification vector. const minEntraRefreshInterval = 5 * time.Minute +// kidPattern restricts an attacker-controlled JWT `kid` to URL-safe characters +// before it is used to build the Okta key-fetch URL. `kid` is read from the +// unverified token header (it selects the verification key, so it must be read +// before the signature is checked), making it untrusted input. Okta key ids are +// URL-safe base64 thumbprints, so this allowlist rejects nothing legitimate +// while preventing path traversal, host injection, and scheme tricks from +// reshaping the outbound request (SSRF / key-confusion). +var kidPattern = regexp.MustCompile(`^[A-Za-z0-9_-]{1,200}$`) + +// keyFetchClient is used for all outbound IdP key fetches (the Okta per-kid PEM +// endpoint and the Entra JWKS document). It sets a timeout so a slow or hung key +// endpoint cannot pin a request goroutine, and it refuses to follow redirects so +// a 3xx response cannot bounce the fetch to an unexpected host - defense in depth +// alongside kid validation. http.DefaultClient (no timeout, follows redirects) +// must not be used for these. +var keyFetchClient = &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, +} + var ( ErrUntrustedIssuer = errors.New("token issuer is not trusted") ErrWrongTenant = errors.New("token tenant is not the trusted Entra tenant") @@ -180,6 +203,12 @@ func oktaKey(token *jwt.Token) (interface{}, error) { if !ok { return nil, errors.New("token missing kid") } + // Validate before anything else: kid is attacker-controlled and is about to + // be concatenated into an outbound URL. Restricting it to URL-safe characters + // is what makes the concatenation below safe (no traversal/host injection). + if !kidPattern.MatchString(kid) { + return nil, errors.New("invalid kid") + } keysMu.RLock() cached, ok := keys[kid] @@ -190,9 +219,12 @@ func oktaKey(token *jwt.Token) (interface{}, error) { cfg := config.GetInstance() url := cfg.Auth.TokenKeyUrl + kid - req, _ := http.NewRequest("GET", url, nil) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } - res, err := http.DefaultClient.Do(req) + res, err := keyFetchClient.Do(req) if err != nil { return nil, err } @@ -273,7 +305,7 @@ func refreshEntraKeys() error { return errors.New("entra JWKS URL not configured") } - res, err := http.DefaultClient.Get(cfg.Auth.EntraJWKSUrl) + res, err := keyFetchClient.Get(cfg.Auth.EntraJWKSUrl) if err != nil { return err } diff --git a/backend/cmd/api/internal/auth/token_test.go b/backend/cmd/api/internal/auth/token_test.go index 05d943c..3c63526 100644 --- a/backend/cmd/api/internal/auth/token_test.go +++ b/backend/cmd/api/internal/auth/token_test.go @@ -1,14 +1,22 @@ package auth import ( + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" + "crypto/x509" "encoding/base64" + "encoding/pem" "math/big" + "net/http" + "net/http/httptest" "os" + "strings" "testing" "time" + "github.com/CMS-Enterprise/ztmf/backend/internal/config" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -159,3 +167,97 @@ func TestDecodeJWT_HS256(t *testing.T) { assert.Error(t, err) }) } + +// TestKidPattern verifies the allowlist that gates the attacker-controlled kid +// before it is used to build the Okta key-fetch URL. Legitimate Okta key ids +// (URL-safe base64 / UUID-style) pass; traversal, host-injection, scheme, and +// oversized values are rejected. +func TestKidPattern(t *testing.T) { + valid := []string{ + "abc123", + "AbC_123-xyz", + "0J8e2cF3aB4dQ5sZ_w-T", + "550e8400e29b41d4a716446655440000", + strings.Repeat("a", 200), // length boundary: 200 is the inclusive max + } + invalid := []string{ + "", + "../../../../latest/meta-data", // path traversal + "a/b", // path separator + "key.pem", // dot (traversal building block) + "@attacker.example", // userinfo/host injection + "http://evil.example/k", // scheme injection + "%2e%2e%2f", // url-encoded traversal + "a b", // whitespace + "a:b", // colon + strings.Repeat("a", 201), // length boundary: 201 is one over the max + } + for _, k := range valid { + assert.True(t, kidPattern.MatchString(k), "expected valid kid: %q", k) + } + for _, k := range invalid { + assert.False(t, kidPattern.MatchString(k), "expected invalid kid: %q", k) + } +} + +// TestOktaKeyRejectsMaliciousKid confirms oktaKey rejects a malicious kid with +// the validation error *before* attempting any outbound fetch (a network attempt +// would surface a different error). This is the SSRF / key-confusion guard. +func TestOktaKeyRejectsMaliciousKid(t *testing.T) { + for _, kid := range []string{ + "../../../../latest/meta-data/iam/security-credentials/", + "@attacker.example/key.pem", + "http://evil.example/key.pem", + "a/b", + strings.Repeat("a", 500), + } { + tok := &jwt.Token{Header: map[string]any{"alg": "ES256", "kid": kid}} + _, err := oktaKey(tok) + require.Error(t, err) + assert.Equal(t, "invalid kid", err.Error(), "kid %q must be rejected before any fetch", kid) + } +} + +// TestOktaKeyMissingKid covers the no-kid header case. +func TestOktaKeyMissingKid(t *testing.T) { + tok := &jwt.Token{Header: map[string]any{"alg": "ES256"}} + _, err := oktaKey(tok) + require.Error(t, err) + assert.Equal(t, "token missing kid", err.Error()) +} + +// TestOktaKeyValidKidFetchesAndParses is the positive regression guard: a valid +// kid builds exactly TokenKeyUrl+kid (no surprise path segments), the fetch is +// served and the PEM parsed into an ECDSA key. Locks in that the hardened client +// + concatenation behave as before for legitimate input. +func TestOktaKeyValidKidFetchesAndParses(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + der, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + require.NoError(t, err) + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der}) + + var gotPath string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + _, _ = w.Write(pemBytes) + })) + defer ts.Close() + + cfg := config.GetInstance() + orig := cfg.Auth.TokenKeyUrl + cfg.Auth.TokenKeyUrl = ts.URL + "/" + defer func() { cfg.Auth.TokenKeyUrl = orig }() + + const kid = "valid-Kid_123" + keysMu.Lock() + delete(keys, kid) + keysMu.Unlock() + + tok := &jwt.Token{Header: map[string]any{"alg": "ES256", "kid": kid}} + key, err := oktaKey(tok) + require.NoError(t, err) + assert.Equal(t, "/"+kid, gotPath, "fetch path is exactly TokenKeyUrl+kid") + _, ok := key.(*ecdsa.PublicKey) + assert.True(t, ok, "returns a parsed ECDSA public key") +}