From e2d828da99b6b786340303524bb2c78972838073 Mon Sep 17 00:00:00 2001 From: Tharun N V Date: Sat, 20 Jun 2026 16:06:25 +0530 Subject: [PATCH 1/4] feat: implement JWKS endpoint (#171) --- internal/config/config.go | 58 +++++++++++++++++++------- internal/dto/jwks.go | 16 +++++++ internal/handler/jwks_handler.go | 56 +++++++++++++++++++++++++ internal/routes/routes.go | 6 +++ internal/service/token_service.go | 27 ++++++------ internal/service/token_service_test.go | 19 +++++++-- 6 files changed, 151 insertions(+), 31 deletions(-) create mode 100644 internal/dto/jwks.go create mode 100644 internal/handler/jwks_handler.go diff --git a/internal/config/config.go b/internal/config/config.go index b1f5eab..98dd5d6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,13 @@ package config import ( + "crypto/rand" + "crypto/rsa" "log" "os" "strconv" + "github.com/golang-jwt/jwt/v5" "github.com/joho/godotenv" ) @@ -36,11 +39,12 @@ type RedisConfig struct { } type JWTConfig struct { - AccessSecret string - RefreshSecret string - AccessExpiry string - RefreshExpiry string - RefreshGracePeriod string + PrivateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey + KeyID string + AccessExpiry string + RefreshExpiry string + RefreshGracePeriod string } type OAuthConfig struct { Google GoogleOAuthConfig @@ -112,14 +116,8 @@ func LoadConfig() *Config { appURL := getEnv("APP_URL", "http://localhost:3000") - accessSecret := getEnv("JWT_SECRET", "") - refreshSecret := getEnv("JWT_REFRESH_SECRET", "") - if len(accessSecret) < 32 { - log.Fatal("JWT_SECRET must be set and at least 32 bytes long") - } - if len(refreshSecret) < 32 { - log.Fatal("JWT_REFRESH_SECRET must be set and at least 32 bytes long") - } + privKey, pubKey := loadRSAKeys() + keyID := getEnv("JWT_KEY_ID", "default-key-1") encKey := getEnv("ENCRYPTION_KEY", "") if encKey == "" || encKey == "0123456789abcdef0123456789abcdef" { @@ -142,8 +140,9 @@ func LoadConfig() *Config { TTL: redisTTL, }, JWT: JWTConfig{ - AccessSecret: getEnv("JWT_SECRET", ""), - RefreshSecret: getEnv("JWT_REFRESH_SECRET", ""), + PrivateKey: privKey, + PublicKey: pubKey, + KeyID: keyID, AccessExpiry: getEnv("JWT_ACCESS_EXPIRY", "15m"), RefreshExpiry: getEnv("JWT_REFRESH_EXPIRY", "168h"), RefreshGracePeriod: getEnv("JWT_REFRESH_GRACE_PERIOD", "10s"), @@ -187,3 +186,32 @@ func getEnv(key, defaultValue string) string { } return defaultValue } + +func loadRSAKeys() (*rsa.PrivateKey, *rsa.PublicKey) { + privPath := getEnv("JWT_PRIVATE_KEY_PATH", "private.pem") + pubPath := getEnv("JWT_PUBLIC_KEY_PATH", "public.pem") + + privBytes, err1 := os.ReadFile(privPath) + pubBytes, err2 := os.ReadFile(pubPath) + + if err1 != nil || err2 != nil { + log.Println("RSA keys not found at provided paths, generating temporary in-memory keys for development/testing...") + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatalf("Failed to generate temp RSA key: %v", err) + } + return privKey, &privKey.PublicKey + } + + privKey, err := jwt.ParseRSAPrivateKeyFromPEM(privBytes) + if err != nil { + log.Fatalf("Failed to parse private key: %v", err) + } + + pubKey, err := jwt.ParseRSAPublicKeyFromPEM(pubBytes) + if err != nil { + log.Fatalf("Failed to parse public key: %v", err) + } + + return privKey, pubKey +} diff --git a/internal/dto/jwks.go b/internal/dto/jwks.go new file mode 100644 index 0000000..ee2dc32 --- /dev/null +++ b/internal/dto/jwks.go @@ -0,0 +1,16 @@ +package dto + +// JWK represents a single JSON Web Key +type JWK struct { + Kty string `json:"kty"` // Key Type + Alg string `json:"alg"` // Algorithm + Use string `json:"use"` // Public Key Use (e.g., "sig") + Kid string `json:"kid"` // Key ID + N string `json:"n"` // Modulus (Base64url encoded) + E string `json:"e"` // Exponent (Base64url encoded) +} + +// JWKSResponse represents a JSON Web Key Set +type JWKSResponse struct { + Keys []JWK `json:"keys"` +} diff --git a/internal/handler/jwks_handler.go b/internal/handler/jwks_handler.go new file mode 100644 index 0000000..486d7bd --- /dev/null +++ b/internal/handler/jwks_handler.go @@ -0,0 +1,56 @@ +package handler + +import ( + "encoding/base64" + "math/big" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/roshankumar0036singh/auth-server/internal/config" + "github.com/roshankumar0036singh/auth-server/internal/dto" +) + +type JWKSHandler struct { + cfg *config.Config +} + +func NewJWKSHandler(cfg *config.Config) *JWKSHandler { + return &JWKSHandler{cfg: cfg} +} + +// GetJWKS returns the public keys in JWKS format +// @Summary Get JSON Web Key Set +// @Description Returns the public keys used to verify JWTs issued by this server +// @Tags OpenID Connect +// @Produce json +// @Success 200 {object} dto.JWKSResponse +// @Router /.well-known/jwks.json [get] +func (h *JWKSHandler) GetJWKS(c *gin.Context) { + pubKey := h.cfg.JWT.PublicKey + + if pubKey == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "public key not configured"}) + return + } + + // The modulus N needs to be BigEndian bytes, encoded as Base64Url (without padding) + nBytes := pubKey.N.Bytes() + nStr := base64.RawURLEncoding.EncodeToString(nBytes) + + // Exponent E is an int, need to encode to bytes then Base64Url + eBytes := big.NewInt(int64(pubKey.E)).Bytes() + eStr := base64.RawURLEncoding.EncodeToString(eBytes) + + jwk := dto.JWK{ + Kty: "RSA", + Alg: "RS256", + Use: "sig", + Kid: h.cfg.JWT.KeyID, + N: nStr, + E: eStr, + } + + c.JSON(http.StatusOK, dto.JWKSResponse{ + Keys: []dto.JWK{jwk}, + }) +} diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 4bb32b9..226cbc2 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -71,6 +71,7 @@ func SetupRoutes(router *gin.Engine, db *gorm.DB, redisClient *redis.Client, cfg adminHandler := handler.NewAdminHandler(authService) oauthClientHandler := handler.NewOAuthClientHandler(oauthProviderService) oauthHandler := handler.NewOAuthHandler(oauthProviderService, userRepo) + jwksHandler := handler.NewJWKSHandler(cfg) // Apply global middleware router.Use(middleware.CORSMiddleware(cfg)) @@ -121,6 +122,11 @@ func SetupRoutes(router *gin.Engine, db *gorm.DB, redisClient *redis.Client, cfg }) }) + // JWKS endpoint + router.GET("/.well-known/jwks.json", jwksHandler.GetJWKS) + + + // OAuth 2.0 Provider endpoints router.GET("/oauth/authorize", middleware.OptionalAuthMiddleware(tokenService, cacheService), oauthHandler.Authorize) router.POST("/oauth/authorize", middleware.AuthMiddleware(tokenService, cacheService), oauthHandler.AuthorizePost) diff --git a/internal/service/token_service.go b/internal/service/token_service.go index d66a349..c7d5c07 100644 --- a/internal/service/token_service.go +++ b/internal/service/token_service.go @@ -59,8 +59,9 @@ func (s *TokenService) GenerateAccessToken(user *models.User, sessionID string) }, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(s.cfg.JWT.AccessSecret)) + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = s.cfg.JWT.KeyID + tokenString, err := token.SignedString(s.cfg.JWT.PrivateKey) if err != nil { return "", err } @@ -84,8 +85,9 @@ func (s *TokenService) GenerateRefreshToken(user *models.User) (string, error) { }, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(s.cfg.JWT.RefreshSecret)) + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = s.cfg.JWT.KeyID + tokenString, err := token.SignedString(s.cfg.JWT.PrivateKey) if err != nil { return "", err } @@ -96,10 +98,10 @@ func (s *TokenService) GenerateRefreshToken(user *models.User) (string, error) { // ValidateAccessToken validates and parses an access token func (s *TokenService) ValidateAccessToken(tokenString string) (*JWTClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, errors.New(errInvalidSignMethod) } - return []byte(s.cfg.JWT.AccessSecret), nil + return s.cfg.JWT.PublicKey, nil }) if err != nil { @@ -133,8 +135,9 @@ func (s *TokenService) GenerateMFAToken(userID string) (string, error) { }, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(s.cfg.JWT.AccessSecret)) + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = s.cfg.JWT.KeyID + return token.SignedString(s.cfg.JWT.PrivateKey) } // ValidateMFAToken validates an MFA-pending token and returns the user ID it @@ -142,10 +145,10 @@ func (s *TokenService) GenerateMFAToken(userID string) (string, error) { // marker, so access/refresh tokens cannot stand in for it. func (s *TokenService) ValidateMFAToken(tokenString string) (string, error) { token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, errors.New(errInvalidSignMethod) } - return []byte(s.cfg.JWT.AccessSecret), nil + return s.cfg.JWT.PublicKey, nil }) if err != nil { return "", err @@ -161,10 +164,10 @@ func (s *TokenService) ValidateMFAToken(tokenString string) (string, error) { // ValidateRefreshToken validates and parses a refresh token func (s *TokenService) ValidateRefreshToken(tokenString string) (*JWTClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, errors.New(errInvalidSignMethod) } - return []byte(s.cfg.JWT.RefreshSecret), nil + return s.cfg.JWT.PublicKey, nil }) if err != nil { diff --git a/internal/service/token_service_test.go b/internal/service/token_service_test.go index 1a085b0..01650f9 100644 --- a/internal/service/token_service_test.go +++ b/internal/service/token_service_test.go @@ -1,6 +1,8 @@ package service_test import ( + "crypto/rand" + "crypto/rsa" "testing" "github.com/roshankumar0036singh/auth-server/internal/config" @@ -9,11 +11,18 @@ import ( "github.com/stretchr/testify/assert" ) +func getTestRSAKeys() (*rsa.PrivateKey, *rsa.PublicKey) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + return priv, &priv.PublicKey +} + func TestTokenService_GenerateAccessToken(t *testing.T) { + priv, pub := getTestRSAKeys() cfg := &config.Config{ JWT: config.JWTConfig{ - AccessSecret: "test-secret", - RefreshSecret: "test-refresh-secret", + PrivateKey: priv, + PublicKey: pub, + KeyID: "test-key", }, } svc := service.NewTokenService(cfg) @@ -39,10 +48,12 @@ func TestTokenService_GenerateAccessToken(t *testing.T) { } func TestTokenService_GenerateRefreshToken(t *testing.T) { + priv, pub := getTestRSAKeys() cfg := &config.Config{ JWT: config.JWTConfig{ - AccessSecret: "test-secret", - RefreshSecret: "test-refresh-secret", + PrivateKey: priv, + PublicKey: pub, + KeyID: "test-key", }, } svc := service.NewTokenService(cfg) From 0057a66ec4c882e5e7d43adf5d2583e6d6fa0fa6 Mon Sep 17 00:00:00 2001 From: Tharun N V Date: Sat, 20 Jun 2026 16:08:25 +0530 Subject: [PATCH 2/4] feat: implement OIDC discovery endpoint (#170) --- internal/dto/oidc_discovery.go | 15 ++++++ internal/handler/oidc_discovery_handler.go | 58 ++++++++++++++++++++++ internal/routes/routes.go | 4 ++ 3 files changed, 77 insertions(+) create mode 100644 internal/dto/oidc_discovery.go create mode 100644 internal/handler/oidc_discovery_handler.go diff --git a/internal/dto/oidc_discovery.go b/internal/dto/oidc_discovery.go new file mode 100644 index 0000000..0f0474e --- /dev/null +++ b/internal/dto/oidc_discovery.go @@ -0,0 +1,15 @@ +package dto + +// OIDCDiscoveryResponse represents the standard OpenID Connect discovery configuration +type OIDCDiscoveryResponse struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + JwksURI string `json:"jwks_uri"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` +} diff --git a/internal/handler/oidc_discovery_handler.go b/internal/handler/oidc_discovery_handler.go new file mode 100644 index 0000000..e91694d --- /dev/null +++ b/internal/handler/oidc_discovery_handler.go @@ -0,0 +1,58 @@ +package handler + +import ( + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/roshankumar0036singh/auth-server/internal/config" + "github.com/roshankumar0036singh/auth-server/internal/dto" +) + +type OIDCDiscoveryHandler struct { + cfg *config.Config +} + +func NewOIDCDiscoveryHandler(cfg *config.Config) *OIDCDiscoveryHandler { + return &OIDCDiscoveryHandler{cfg: cfg} +} + +// GetConfiguration returns the OpenID Connect discovery document +// @Summary Get OpenID Connect Configuration +// @Description Returns the OpenID Connect discovery metadata document +// @Tags OpenID Connect +// @Produce json +// @Success 200 {object} dto.OIDCDiscoveryResponse +// @Router /.well-known/openid-configuration [get] +func (h *OIDCDiscoveryHandler) GetConfiguration(c *gin.Context) { + baseURL := strings.TrimRight(h.cfg.App.URL, "/") + + response := dto.OIDCDiscoveryResponse{ + Issuer: baseURL, + AuthorizationEndpoint: fmt.Sprintf("%s/oauth/authorize", baseURL), + TokenEndpoint: fmt.Sprintf("%s/oauth/token", baseURL), + UserinfoEndpoint: fmt.Sprintf("%s/oauth/userinfo", baseURL), + JwksURI: fmt.Sprintf("%s/.well-known/jwks.json", baseURL), + ScopesSupported: []string{ + "openid", + "profile", + "email", + }, + ResponseTypesSupported: []string{ + "code", + }, + GrantTypesSupported: []string{ + "authorization_code", + "refresh_token", + }, + SubjectTypesSupported: []string{ + "public", + }, + IDTokenSigningAlgValuesSupported: []string{ + "RS256", + }, + } + + c.JSON(http.StatusOK, response) +} diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 226cbc2..fd010cf 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -72,6 +72,7 @@ func SetupRoutes(router *gin.Engine, db *gorm.DB, redisClient *redis.Client, cfg oauthClientHandler := handler.NewOAuthClientHandler(oauthProviderService) oauthHandler := handler.NewOAuthHandler(oauthProviderService, userRepo) jwksHandler := handler.NewJWKSHandler(cfg) + discoveryHandler := handler.NewOIDCDiscoveryHandler(cfg) // Apply global middleware router.Use(middleware.CORSMiddleware(cfg)) @@ -125,6 +126,9 @@ func SetupRoutes(router *gin.Engine, db *gorm.DB, redisClient *redis.Client, cfg // JWKS endpoint router.GET("/.well-known/jwks.json", jwksHandler.GetJWKS) + // OIDC Discovery endpoint + router.GET("/.well-known/openid-configuration", discoveryHandler.GetConfiguration) + // OAuth 2.0 Provider endpoints From 5de8d4726907e72df2a2d495821d170103d325c3 Mon Sep 17 00:00:00 2001 From: Tharun N V Date: Sat, 20 Jun 2026 16:37:04 +0530 Subject: [PATCH 3/4] fix: address CodeRabbit review feedback for JWKS endpoint --- docs/docs.go | 60 +++++++++++++++++++ docs/swagger.json | 60 +++++++++++++++++++ docs/swagger.yaml | 41 +++++++++++++ internal/config/config.go | 4 +- .../handler/auth_handler_protected_test.go | 3 +- internal/handler/auth_handler_test.go | 6 +- internal/handler/jwks_handler.go | 2 +- internal/handler/oauth_handler_test.go | 9 ++- internal/middleware/auth_test.go | 4 +- internal/service/security_phase1_test.go | 9 +-- internal/service/token_service_test.go | 12 +--- internal/testutils/setup.go | 24 +++++++- 12 files changed, 211 insertions(+), 23 deletions(-) diff --git a/docs/docs.go b/docs/docs.go index ee2772d..c77e274 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -23,6 +23,26 @@ const docTemplate = `{ "host": "{{.Host}}", "basePath": "{{.BasePath}}", "paths": { + "/.well-known/jwks.json": { + "get": { + "description": "Returns the public keys used to verify JWTs issued by this server", + "produces": [ + "application/json" + ], + "tags": [ + "OpenID Connect" + ], + "summary": "Get JSON Web Key Set", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse" + } + } + } + } + }, "/api/admin/users": { "get": { "security": [ @@ -1405,6 +1425,46 @@ const docTemplate = `{ } } }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWK": { + "type": "object", + "properties": { + "alg": { + "description": "Algorithm", + "type": "string" + }, + "e": { + "description": "Exponent (Base64url encoded)", + "type": "string" + }, + "kid": { + "description": "Key ID", + "type": "string" + }, + "kty": { + "description": "Key Type", + "type": "string" + }, + "n": { + "description": "Modulus (Base64url encoded)", + "type": "string" + }, + "use": { + "description": "Public Key Use (e.g., \"sig\")", + "type": "string" + } + } + }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWK" + } + } + } + }, "github_com_roshankumar0036singh_auth-server_internal_dto.LoginRequest": { "type": "object", "required": [ diff --git a/docs/swagger.json b/docs/swagger.json index 6486388..ef06bae 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -21,6 +21,26 @@ "host": "localhost:8080", "basePath": "/", "paths": { + "/.well-known/jwks.json": { + "get": { + "description": "Returns the public keys used to verify JWTs issued by this server", + "produces": [ + "application/json" + ], + "tags": [ + "OpenID Connect" + ], + "summary": "Get JSON Web Key Set", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse" + } + } + } + } + }, "/api/admin/users": { "get": { "security": [ @@ -1403,6 +1423,46 @@ } } }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWK": { + "type": "object", + "properties": { + "alg": { + "description": "Algorithm", + "type": "string" + }, + "e": { + "description": "Exponent (Base64url encoded)", + "type": "string" + }, + "kid": { + "description": "Key ID", + "type": "string" + }, + "kty": { + "description": "Key Type", + "type": "string" + }, + "n": { + "description": "Modulus (Base64url encoded)", + "type": "string" + }, + "use": { + "description": "Public Key Use (e.g., \"sig\")", + "type": "string" + } + } + }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWK" + } + } + } + }, "github_com_roshankumar0036singh_auth-server_internal_dto.LoginRequest": { "type": "object", "required": [ diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 8fb90f0..f68bc53 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -18,6 +18,34 @@ definitions: required: - email type: object + github_com_roshankumar0036singh_auth-server_internal_dto.JWK: + properties: + alg: + description: Algorithm + type: string + e: + description: Exponent (Base64url encoded) + type: string + kid: + description: Key ID + type: string + kty: + description: Key Type + type: string + "n": + description: Modulus (Base64url encoded) + type: string + use: + description: Public Key Use (e.g., "sig") + type: string + type: object + github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse: + properties: + keys: + items: + $ref: '#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWK' + type: array + type: object github_com_roshankumar0036singh_auth-server_internal_dto.LoginRequest: properties: email: @@ -285,6 +313,19 @@ info: title: Auth Server API version: "1.0" paths: + /.well-known/jwks.json: + get: + description: Returns the public keys used to verify JWTs issued by this server + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse' + summary: Get JSON Web Key Set + tags: + - OpenID Connect /api/admin/users: get: produces: diff --git a/internal/config/config.go b/internal/config/config.go index 98dd5d6..0d887b5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -194,13 +194,15 @@ func loadRSAKeys() (*rsa.PrivateKey, *rsa.PublicKey) { privBytes, err1 := os.ReadFile(privPath) pubBytes, err2 := os.ReadFile(pubPath) - if err1 != nil || err2 != nil { + if os.IsNotExist(err1) || os.IsNotExist(err2) { log.Println("RSA keys not found at provided paths, generating temporary in-memory keys for development/testing...") privKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { log.Fatalf("Failed to generate temp RSA key: %v", err) } return privKey, &privKey.PublicKey + } else if err1 != nil || err2 != nil { + log.Fatalf("Failed to read RSA keys: privErr=%v, pubErr=%v", err1, err2) } privKey, err := jwt.ParseRSAPrivateKeyFromPEM(privBytes) diff --git a/internal/handler/auth_handler_protected_test.go b/internal/handler/auth_handler_protected_test.go index 39a0cb6..b71eaa0 100644 --- a/internal/handler/auth_handler_protected_test.go +++ b/internal/handler/auth_handler_protected_test.go @@ -23,7 +23,8 @@ func TestAuthHandler_GetMe(t *testing.T) { defer mr.Close() authHandler := handler.NewAuthHandler(authService, nil, nil) // We need TokenService to create a valid token for the middleware - cfg := &config.Config{JWT: config.JWTConfig{AccessSecret: "secret"}} + priv, pub := testutils.GetTestRSAKeys(t) + cfg := &config.Config{JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}} tokenService := service.NewTokenService(cfg) rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) diff --git a/internal/handler/auth_handler_test.go b/internal/handler/auth_handler_test.go index 8c395c2..3759ad1 100644 --- a/internal/handler/auth_handler_test.go +++ b/internal/handler/auth_handler_test.go @@ -106,10 +106,12 @@ func TestAuthHandler_GetSessions_CurrentSessionFlag(t *testing.T) { authHandler := NewAuthHandler(authService, nil, nil) + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ JWT: config.JWTConfig{ - AccessSecret: "secret", - RefreshSecret: "refresh-secret", + PrivateKey: priv, + PublicKey: pub, + KeyID: "test-key", }, } tokenService := service.NewTokenService(cfg) diff --git a/internal/handler/jwks_handler.go b/internal/handler/jwks_handler.go index 486d7bd..0ce60a7 100644 --- a/internal/handler/jwks_handler.go +++ b/internal/handler/jwks_handler.go @@ -29,7 +29,7 @@ func (h *JWKSHandler) GetJWKS(c *gin.Context) { pubKey := h.cfg.JWT.PublicKey if pubKey == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "public key not configured"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "public key not configured", "code": "INTERNAL_SERVER_ERROR"}) return } diff --git a/internal/handler/oauth_handler_test.go b/internal/handler/oauth_handler_test.go index 9863331..3a29ec1 100644 --- a/internal/handler/oauth_handler_test.go +++ b/internal/handler/oauth_handler_test.go @@ -29,6 +29,7 @@ func setupOAuthUserInfoRouter(t *testing.T) (*gin.Engine, *repository.UserReposi userRepo := repository.NewUserRepository(db) tokenRepo := repository.NewOAuthTokenRepository(db) + priv, pub := testutils.GetTestRSAKeys(t) oauthProviderService := service.NewOAuthProviderService( repository.NewOAuthClientRepository(db), repository.NewAuthorizationCodeRepository(db), @@ -36,7 +37,7 @@ func setupOAuthUserInfoRouter(t *testing.T) (*gin.Engine, *repository.UserReposi repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), service.NewTokenService(&config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, }), &config.Config{}, ) @@ -76,6 +77,7 @@ func TestNewOAuthHandlerPanicsWithoutUserRepository(t *testing.T) { t.Cleanup(func() { mr.Close() }) tokenRepo := repository.NewOAuthTokenRepository(db) + priv, pub := testutils.GetTestRSAKeys(t) oauthProviderService := service.NewOAuthProviderService( repository.NewOAuthClientRepository(db), repository.NewAuthorizationCodeRepository(db), @@ -83,7 +85,7 @@ func TestNewOAuthHandlerPanicsWithoutUserRepository(t *testing.T) { repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), service.NewTokenService(&config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, }), &config.Config{}, ) @@ -358,6 +360,7 @@ func setupTokenRouter(t *testing.T) (*gin.Engine, *repository.OAuthClientReposit tokenRepo := repository.NewOAuthTokenRepository(db) userRepo := repository.NewUserRepository(db) + priv, pub := testutils.GetTestRSAKeys(t) oauthProviderService := service.NewOAuthProviderService( clientRepo, codeRepo, @@ -365,7 +368,7 @@ func setupTokenRouter(t *testing.T) (*gin.Engine, *repository.OAuthClientReposit repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), service.NewTokenService(&config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, }), &config.Config{}, ) diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go index cc8bff1..6012365 100644 --- a/internal/middleware/auth_test.go +++ b/internal/middleware/auth_test.go @@ -14,6 +14,7 @@ import ( "github.com/roshankumar0036singh/auth-server/internal/middleware" "github.com/roshankumar0036singh/auth-server/internal/models" "github.com/roshankumar0036singh/auth-server/internal/service" + "github.com/roshankumar0036singh/auth-server/internal/testutils" "github.com/stretchr/testify/assert" ) @@ -24,8 +25,9 @@ func setupTest(t *testing.T) (*gin.Engine, *service.TokenService, *service.Cache rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) t.Cleanup(func() { _ = rdb.Close() }) + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh", AccessExpiry: "15m"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key", AccessExpiry: "15m"}, } tokenService := service.NewTokenService(cfg) cacheService := service.NewCacheService(rdb) diff --git a/internal/service/security_phase1_test.go b/internal/service/security_phase1_test.go index 9b51e80..fdcef10 100644 --- a/internal/service/security_phase1_test.go +++ b/internal/service/security_phase1_test.go @@ -18,9 +18,10 @@ import ( // testCfg mirrors the secrets used by testutils.SetupIntegrationTest so a // TokenService built here produces tokens the integration AuthService accepts. -func testCfg() *config.Config { +func testCfg(t *testing.T) *config.Config { + priv, pub := testutils.GetTestRSAKeys(t) return &config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, Security: config.SecurityConfig{RateLimitMax: 10, EncryptionKey: "12345678901234567890123456789012"}, App: config.AppConfig{URL: "http://localhost"}, } @@ -37,8 +38,8 @@ func newProviderService(t *testing.T) (*service.OAuthProviderService, *repositor tokenRepo, repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), - service.NewTokenService(testCfg()), - testCfg(), + service.NewTokenService(testCfg(t)), + testCfg(t), ) return ps, tokenRepo } diff --git a/internal/service/token_service_test.go b/internal/service/token_service_test.go index 01650f9..7170557 100644 --- a/internal/service/token_service_test.go +++ b/internal/service/token_service_test.go @@ -1,23 +1,17 @@ package service_test import ( - "crypto/rand" - "crypto/rsa" "testing" "github.com/roshankumar0036singh/auth-server/internal/config" "github.com/roshankumar0036singh/auth-server/internal/models" "github.com/roshankumar0036singh/auth-server/internal/service" + "github.com/roshankumar0036singh/auth-server/internal/testutils" "github.com/stretchr/testify/assert" ) -func getTestRSAKeys() (*rsa.PrivateKey, *rsa.PublicKey) { - priv, _ := rsa.GenerateKey(rand.Reader, 2048) - return priv, &priv.PublicKey -} - func TestTokenService_GenerateAccessToken(t *testing.T) { - priv, pub := getTestRSAKeys() + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ JWT: config.JWTConfig{ PrivateKey: priv, @@ -48,7 +42,7 @@ func TestTokenService_GenerateAccessToken(t *testing.T) { } func TestTokenService_GenerateRefreshToken(t *testing.T) { - priv, pub := getTestRSAKeys() + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ JWT: config.JWTConfig{ PrivateKey: priv, diff --git a/internal/testutils/setup.go b/internal/testutils/setup.go index bfffd4d..714fba0 100644 --- a/internal/testutils/setup.go +++ b/internal/testutils/setup.go @@ -1,6 +1,9 @@ package testutils import ( + "crypto/rand" + "crypto/rsa" + "sync" "testing" "github.com/alicebob/miniredis/v2" @@ -14,6 +17,24 @@ import ( "gorm.io/gorm" ) +var ( + testPrivKey *rsa.PrivateKey + testPubKey *rsa.PublicKey + testKeyOnce sync.Once +) + +func GetTestRSAKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { + testKeyOnce.Do(func() { + var err error + testPrivKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate test RSA key: %v", err) + } + testPubKey = &testPrivKey.PublicKey + }) + return testPrivKey, testPubKey +} + // MockEmailSender type MockEmailSender struct { LastEmail map[string]string @@ -109,8 +130,9 @@ func SetupIntegrationTest(t *testing.T) (*service.AuthService, *gorm.DB, *minire auditRepo := repository.NewAuditRepository(db) // 4. Services + priv, pub := GetTestRSAKeys(t) cfg := &config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, Security: config.SecurityConfig{RateLimitMax: 10, RateLimitWindow: 60}, App: config.AppConfig{URL: "http://localhost"}, } From 636bc70d224db949ecb300ea9f0b783cfe6d99a7 Mon Sep 17 00:00:00 2001 From: Tharun N V Date: Sat, 20 Jun 2026 16:44:02 +0530 Subject: [PATCH 4/4] refactor: extract duplicate token validation and signing logic --- internal/service/token_service.go | 102 ++++++++++++------------------ 1 file changed, 41 insertions(+), 61 deletions(-) diff --git a/internal/service/token_service.go b/internal/service/token_service.go index c7d5c07..1a7021e 100644 --- a/internal/service/token_service.go +++ b/internal/service/token_service.go @@ -42,6 +42,31 @@ const ( errInvalidToken = "invalid token" ) +func (s *TokenService) signToken(claims *JWTClaims) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = s.cfg.JWT.KeyID + return token.SignedString(s.cfg.JWT.PrivateKey) +} + +func (s *TokenService) parseToken(tokenString string) (*JWTClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, errors.New(errInvalidSignMethod) + } + return s.cfg.JWT.PublicKey, nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { + return claims, nil + } + + return nil, errors.New(errInvalidToken) +} + // GenerateAccessToken generates a new JWT access token func (s *TokenService) GenerateAccessToken(user *models.User, sessionID string) (string, error) { expirationTime := time.Now().Add(15 * time.Minute) // 15 minutes @@ -59,14 +84,7 @@ func (s *TokenService) GenerateAccessToken(user *models.User, sessionID string) }, } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header["kid"] = s.cfg.JWT.KeyID - tokenString, err := token.SignedString(s.cfg.JWT.PrivateKey) - if err != nil { - return "", err - } - - return tokenString, nil + return s.signToken(claims) } // GenerateRefreshToken generates a new refresh token (longer expiry) @@ -85,39 +103,23 @@ func (s *TokenService) GenerateRefreshToken(user *models.User) (string, error) { }, } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header["kid"] = s.cfg.JWT.KeyID - tokenString, err := token.SignedString(s.cfg.JWT.PrivateKey) - if err != nil { - return "", err - } - - return tokenString, nil + return s.signToken(claims) } // ValidateAccessToken validates and parses an access token func (s *TokenService) ValidateAccessToken(tokenString string) (*JWTClaims, error) { - token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, errors.New(errInvalidSignMethod) - } - return s.cfg.JWT.PublicKey, nil - }) - + claims, err := s.parseToken(tokenString) if err != nil { return nil, err } - - if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { - // Purpose-scoped tokens (e.g. the MFA-pending token) must never be - // accepted as access tokens. - if claims.Purpose != "" { - return nil, errors.New(errInvalidToken) - } - return claims, nil + + // Purpose-scoped tokens (e.g. the MFA-pending token) must never be + // accepted as access tokens. + if claims.Purpose != "" { + return nil, errors.New(errInvalidToken) } - - return nil, errors.New(errInvalidToken) + + return claims, nil } // GenerateMFAToken issues a short-lived token proving the password step of @@ -135,48 +137,26 @@ func (s *TokenService) GenerateMFAToken(userID string) (string, error) { }, } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header["kid"] = s.cfg.JWT.KeyID - return token.SignedString(s.cfg.JWT.PrivateKey) + return s.signToken(claims) } // ValidateMFAToken validates an MFA-pending token and returns the user ID it // was issued for. It rejects any token whose Purpose is not the MFA-pending // marker, so access/refresh tokens cannot stand in for it. func (s *TokenService) ValidateMFAToken(tokenString string) (string, error) { - token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, errors.New(errInvalidSignMethod) - } - return s.cfg.JWT.PublicKey, nil - }) + claims, err := s.parseToken(tokenString) if err != nil { return "", err } - - claims, ok := token.Claims.(*JWTClaims) - if !ok || !token.Valid || claims.Purpose != mfaPendingPurpose { + + if claims.Purpose != mfaPendingPurpose { return "", errors.New("invalid mfa token") } + return claims.UserID, nil } // ValidateRefreshToken validates and parses a refresh token func (s *TokenService) ValidateRefreshToken(tokenString string) (*JWTClaims, error) { - token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, errors.New(errInvalidSignMethod) - } - return s.cfg.JWT.PublicKey, nil - }) - - if err != nil { - return nil, err - } - - if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { - return claims, nil - } - - return nil, errors.New(errInvalidToken) + return s.parseToken(tokenString) }