Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions backend/cmd/api/internal/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"log"
"math/big"
"net/http"
"regexp"
"slices"
"sync"
"time"
Expand Down Expand Up @@ -46,6 +47,28 @@ var (
// new kid is an acceptable trade for closing that amplification vector.
const minEntraRefreshInterval = 5 * time.Minute

// kidPattern restricts an attacker-controlled JWT `kid` to URL-safe characters
// before it is used to build the Okta key-fetch URL. `kid` is read from the
// unverified token header (it selects the verification key, so it must be read
// before the signature is checked), making it untrusted input. Okta key ids are
// URL-safe base64 thumbprints, so this allowlist rejects nothing legitimate
// while preventing path traversal, host injection, and scheme tricks from
// reshaping the outbound request (SSRF / key-confusion).
var kidPattern = regexp.MustCompile(`^[A-Za-z0-9_-]{1,200}$`)

// keyFetchClient is used for all outbound IdP key fetches (the Okta per-kid PEM
// endpoint and the Entra JWKS document). It sets a timeout so a slow or hung key
// endpoint cannot pin a request goroutine, and it refuses to follow redirects so
// a 3xx response cannot bounce the fetch to an unexpected host - defense in depth
// alongside kid validation. http.DefaultClient (no timeout, follows redirects)
// must not be used for these.
var keyFetchClient = &http.Client{
Timeout: 10 * time.Second,
CheckRedirect: func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
},
}

var (
ErrUntrustedIssuer = errors.New("token issuer is not trusted")
ErrWrongTenant = errors.New("token tenant is not the trusted Entra tenant")
Expand Down Expand Up @@ -180,6 +203,12 @@ func oktaKey(token *jwt.Token) (interface{}, error) {
if !ok {
return nil, errors.New("token missing kid")
}
// Validate before anything else: kid is attacker-controlled and is about to
// be concatenated into an outbound URL. Restricting it to URL-safe characters
// is what makes the concatenation below safe (no traversal/host injection).
if !kidPattern.MatchString(kid) {
return nil, errors.New("invalid kid")
}

keysMu.RLock()
cached, ok := keys[kid]
Expand All @@ -190,9 +219,12 @@ func oktaKey(token *jwt.Token) (interface{}, error) {

cfg := config.GetInstance()
url := cfg.Auth.TokenKeyUrl + kid
req, _ := http.NewRequest("GET", url, nil)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

res, err := http.DefaultClient.Do(req)
res, err := keyFetchClient.Do(req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -273,7 +305,7 @@ func refreshEntraKeys() error {
return errors.New("entra JWKS URL not configured")
}

res, err := http.DefaultClient.Get(cfg.Auth.EntraJWKSUrl)
res, err := keyFetchClient.Get(cfg.Auth.EntraJWKSUrl)
if err != nil {
return err
}
Expand Down
102 changes: 102 additions & 0 deletions backend/cmd/api/internal/auth/token_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package auth

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"math/big"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/CMS-Enterprise/ztmf/backend/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -159,3 +167,97 @@ func TestDecodeJWT_HS256(t *testing.T) {
assert.Error(t, err)
})
}

// TestKidPattern verifies the allowlist that gates the attacker-controlled kid
// before it is used to build the Okta key-fetch URL. Legitimate Okta key ids
// (URL-safe base64 / UUID-style) pass; traversal, host-injection, scheme, and
// oversized values are rejected.
func TestKidPattern(t *testing.T) {
valid := []string{
"abc123",
"AbC_123-xyz",
"0J8e2cF3aB4dQ5sZ_w-T",
"550e8400e29b41d4a716446655440000",
strings.Repeat("a", 200), // length boundary: 200 is the inclusive max
}
invalid := []string{
"",
"../../../../latest/meta-data", // path traversal
"a/b", // path separator
"key.pem", // dot (traversal building block)
"@attacker.example", // userinfo/host injection
"http://evil.example/k", // scheme injection
"%2e%2e%2f", // url-encoded traversal
"a b", // whitespace
"a:b", // colon
strings.Repeat("a", 201), // length boundary: 201 is one over the max
}
for _, k := range valid {
assert.True(t, kidPattern.MatchString(k), "expected valid kid: %q", k)
}
for _, k := range invalid {
assert.False(t, kidPattern.MatchString(k), "expected invalid kid: %q", k)
}
}

// TestOktaKeyRejectsMaliciousKid confirms oktaKey rejects a malicious kid with
// the validation error *before* attempting any outbound fetch (a network attempt
// would surface a different error). This is the SSRF / key-confusion guard.
func TestOktaKeyRejectsMaliciousKid(t *testing.T) {
for _, kid := range []string{
"../../../../latest/meta-data/iam/security-credentials/",
"@attacker.example/key.pem",
"http://evil.example/key.pem",
"a/b",
strings.Repeat("a", 500),
} {
tok := &jwt.Token{Header: map[string]any{"alg": "ES256", "kid": kid}}
_, err := oktaKey(tok)
require.Error(t, err)
assert.Equal(t, "invalid kid", err.Error(), "kid %q must be rejected before any fetch", kid)
}
}

// TestOktaKeyMissingKid covers the no-kid header case.
func TestOktaKeyMissingKid(t *testing.T) {
tok := &jwt.Token{Header: map[string]any{"alg": "ES256"}}
_, err := oktaKey(tok)
require.Error(t, err)
assert.Equal(t, "token missing kid", err.Error())
}

// TestOktaKeyValidKidFetchesAndParses is the positive regression guard: a valid
// kid builds exactly TokenKeyUrl+kid (no surprise path segments), the fetch is
// served and the PEM parsed into an ECDSA key. Locks in that the hardened client
// + concatenation behave as before for legitimate input.
func TestOktaKeyValidKidFetchesAndParses(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
der, err := x509.MarshalPKIXPublicKey(&priv.PublicKey)
require.NoError(t, err)
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der})

var gotPath string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
_, _ = w.Write(pemBytes)
}))
defer ts.Close()

cfg := config.GetInstance()
orig := cfg.Auth.TokenKeyUrl
cfg.Auth.TokenKeyUrl = ts.URL + "/"
defer func() { cfg.Auth.TokenKeyUrl = orig }()

const kid = "valid-Kid_123"
keysMu.Lock()
delete(keys, kid)
keysMu.Unlock()

tok := &jwt.Token{Header: map[string]any{"alg": "ES256", "kid": kid}}
key, err := oktaKey(tok)
require.NoError(t, err)
assert.Equal(t, "/"+kid, gotPath, "fetch path is exactly TokenKeyUrl+kid")
_, ok := key.(*ecdsa.PublicKey)
assert.True(t, ok, "returns a parsed ECDSA public key")
}
Loading