diff --git a/docs/docs.go b/docs/docs.go index ee2772d..c77e274 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -23,6 +23,26 @@ const docTemplate = `{ "host": "{{.Host}}", "basePath": "{{.BasePath}}", "paths": { + "/.well-known/jwks.json": { + "get": { + "description": "Returns the public keys used to verify JWTs issued by this server", + "produces": [ + "application/json" + ], + "tags": [ + "OpenID Connect" + ], + "summary": "Get JSON Web Key Set", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse" + } + } + } + } + }, "/api/admin/users": { "get": { "security": [ @@ -1405,6 +1425,46 @@ const docTemplate = `{ } } }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWK": { + "type": "object", + "properties": { + "alg": { + "description": "Algorithm", + "type": "string" + }, + "e": { + "description": "Exponent (Base64url encoded)", + "type": "string" + }, + "kid": { + "description": "Key ID", + "type": "string" + }, + "kty": { + "description": "Key Type", + "type": "string" + }, + "n": { + "description": "Modulus (Base64url encoded)", + "type": "string" + }, + "use": { + "description": "Public Key Use (e.g., \"sig\")", + "type": "string" + } + } + }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWK" + } + } + } + }, "github_com_roshankumar0036singh_auth-server_internal_dto.LoginRequest": { "type": "object", "required": [ diff --git a/docs/swagger.json b/docs/swagger.json index 6486388..ef06bae 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -21,6 +21,26 @@ "host": "localhost:8080", "basePath": "/", "paths": { + "/.well-known/jwks.json": { + "get": { + "description": "Returns the public keys used to verify JWTs issued by this server", + "produces": [ + "application/json" + ], + "tags": [ + "OpenID Connect" + ], + "summary": "Get JSON Web Key Set", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse" + } + } + } + } + }, "/api/admin/users": { "get": { "security": [ @@ -1403,6 +1423,46 @@ } } }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWK": { + "type": "object", + "properties": { + "alg": { + "description": "Algorithm", + "type": "string" + }, + "e": { + "description": "Exponent (Base64url encoded)", + "type": "string" + }, + "kid": { + "description": "Key ID", + "type": "string" + }, + "kty": { + "description": "Key Type", + "type": "string" + }, + "n": { + "description": "Modulus (Base64url encoded)", + "type": "string" + }, + "use": { + "description": "Public Key Use (e.g., \"sig\")", + "type": "string" + } + } + }, + "github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWK" + } + } + } + }, "github_com_roshankumar0036singh_auth-server_internal_dto.LoginRequest": { "type": "object", "required": [ diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 8fb90f0..f68bc53 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -18,6 +18,34 @@ definitions: required: - email type: object + github_com_roshankumar0036singh_auth-server_internal_dto.JWK: + properties: + alg: + description: Algorithm + type: string + e: + description: Exponent (Base64url encoded) + type: string + kid: + description: Key ID + type: string + kty: + description: Key Type + type: string + "n": + description: Modulus (Base64url encoded) + type: string + use: + description: Public Key Use (e.g., "sig") + type: string + type: object + github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse: + properties: + keys: + items: + $ref: '#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWK' + type: array + type: object github_com_roshankumar0036singh_auth-server_internal_dto.LoginRequest: properties: email: @@ -285,6 +313,19 @@ info: title: Auth Server API version: "1.0" paths: + /.well-known/jwks.json: + get: + description: Returns the public keys used to verify JWTs issued by this server + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/github_com_roshankumar0036singh_auth-server_internal_dto.JWKSResponse' + summary: Get JSON Web Key Set + tags: + - OpenID Connect /api/admin/users: get: produces: diff --git a/internal/config/config.go b/internal/config/config.go index b1f5eab..0d887b5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,13 @@ package config import ( + "crypto/rand" + "crypto/rsa" "log" "os" "strconv" + "github.com/golang-jwt/jwt/v5" "github.com/joho/godotenv" ) @@ -36,11 +39,12 @@ type RedisConfig struct { } type JWTConfig struct { - AccessSecret string - RefreshSecret string - AccessExpiry string - RefreshExpiry string - RefreshGracePeriod string + PrivateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey + KeyID string + AccessExpiry string + RefreshExpiry string + RefreshGracePeriod string } type OAuthConfig struct { Google GoogleOAuthConfig @@ -112,14 +116,8 @@ func LoadConfig() *Config { appURL := getEnv("APP_URL", "http://localhost:3000") - accessSecret := getEnv("JWT_SECRET", "") - refreshSecret := getEnv("JWT_REFRESH_SECRET", "") - if len(accessSecret) < 32 { - log.Fatal("JWT_SECRET must be set and at least 32 bytes long") - } - if len(refreshSecret) < 32 { - log.Fatal("JWT_REFRESH_SECRET must be set and at least 32 bytes long") - } + privKey, pubKey := loadRSAKeys() + keyID := getEnv("JWT_KEY_ID", "default-key-1") encKey := getEnv("ENCRYPTION_KEY", "") if encKey == "" || encKey == "0123456789abcdef0123456789abcdef" { @@ -142,8 +140,9 @@ func LoadConfig() *Config { TTL: redisTTL, }, JWT: JWTConfig{ - AccessSecret: getEnv("JWT_SECRET", ""), - RefreshSecret: getEnv("JWT_REFRESH_SECRET", ""), + PrivateKey: privKey, + PublicKey: pubKey, + KeyID: keyID, AccessExpiry: getEnv("JWT_ACCESS_EXPIRY", "15m"), RefreshExpiry: getEnv("JWT_REFRESH_EXPIRY", "168h"), RefreshGracePeriod: getEnv("JWT_REFRESH_GRACE_PERIOD", "10s"), @@ -187,3 +186,34 @@ func getEnv(key, defaultValue string) string { } return defaultValue } + +func loadRSAKeys() (*rsa.PrivateKey, *rsa.PublicKey) { + privPath := getEnv("JWT_PRIVATE_KEY_PATH", "private.pem") + pubPath := getEnv("JWT_PUBLIC_KEY_PATH", "public.pem") + + privBytes, err1 := os.ReadFile(privPath) + pubBytes, err2 := os.ReadFile(pubPath) + + if os.IsNotExist(err1) || os.IsNotExist(err2) { + log.Println("RSA keys not found at provided paths, generating temporary in-memory keys for development/testing...") + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatalf("Failed to generate temp RSA key: %v", err) + } + return privKey, &privKey.PublicKey + } else if err1 != nil || err2 != nil { + log.Fatalf("Failed to read RSA keys: privErr=%v, pubErr=%v", err1, err2) + } + + privKey, err := jwt.ParseRSAPrivateKeyFromPEM(privBytes) + if err != nil { + log.Fatalf("Failed to parse private key: %v", err) + } + + pubKey, err := jwt.ParseRSAPublicKeyFromPEM(pubBytes) + if err != nil { + log.Fatalf("Failed to parse public key: %v", err) + } + + return privKey, pubKey +} diff --git a/internal/dto/jwks.go b/internal/dto/jwks.go new file mode 100644 index 0000000..ee2dc32 --- /dev/null +++ b/internal/dto/jwks.go @@ -0,0 +1,16 @@ +package dto + +// JWK represents a single JSON Web Key +type JWK struct { + Kty string `json:"kty"` // Key Type + Alg string `json:"alg"` // Algorithm + Use string `json:"use"` // Public Key Use (e.g., "sig") + Kid string `json:"kid"` // Key ID + N string `json:"n"` // Modulus (Base64url encoded) + E string `json:"e"` // Exponent (Base64url encoded) +} + +// JWKSResponse represents a JSON Web Key Set +type JWKSResponse struct { + Keys []JWK `json:"keys"` +} diff --git a/internal/dto/oidc_discovery.go b/internal/dto/oidc_discovery.go new file mode 100644 index 0000000..0f0474e --- /dev/null +++ b/internal/dto/oidc_discovery.go @@ -0,0 +1,15 @@ +package dto + +// OIDCDiscoveryResponse represents the standard OpenID Connect discovery configuration +type OIDCDiscoveryResponse struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + JwksURI string `json:"jwks_uri"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` +} diff --git a/internal/handler/auth_handler_protected_test.go b/internal/handler/auth_handler_protected_test.go index 39a0cb6..b71eaa0 100644 --- a/internal/handler/auth_handler_protected_test.go +++ b/internal/handler/auth_handler_protected_test.go @@ -23,7 +23,8 @@ func TestAuthHandler_GetMe(t *testing.T) { defer mr.Close() authHandler := handler.NewAuthHandler(authService, nil, nil) // We need TokenService to create a valid token for the middleware - cfg := &config.Config{JWT: config.JWTConfig{AccessSecret: "secret"}} + priv, pub := testutils.GetTestRSAKeys(t) + cfg := &config.Config{JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}} tokenService := service.NewTokenService(cfg) rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) diff --git a/internal/handler/auth_handler_test.go b/internal/handler/auth_handler_test.go index 8c395c2..3759ad1 100644 --- a/internal/handler/auth_handler_test.go +++ b/internal/handler/auth_handler_test.go @@ -106,10 +106,12 @@ func TestAuthHandler_GetSessions_CurrentSessionFlag(t *testing.T) { authHandler := NewAuthHandler(authService, nil, nil) + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ JWT: config.JWTConfig{ - AccessSecret: "secret", - RefreshSecret: "refresh-secret", + PrivateKey: priv, + PublicKey: pub, + KeyID: "test-key", }, } tokenService := service.NewTokenService(cfg) diff --git a/internal/handler/jwks_handler.go b/internal/handler/jwks_handler.go new file mode 100644 index 0000000..0ce60a7 --- /dev/null +++ b/internal/handler/jwks_handler.go @@ -0,0 +1,56 @@ +package handler + +import ( + "encoding/base64" + "math/big" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/roshankumar0036singh/auth-server/internal/config" + "github.com/roshankumar0036singh/auth-server/internal/dto" +) + +type JWKSHandler struct { + cfg *config.Config +} + +func NewJWKSHandler(cfg *config.Config) *JWKSHandler { + return &JWKSHandler{cfg: cfg} +} + +// GetJWKS returns the public keys in JWKS format +// @Summary Get JSON Web Key Set +// @Description Returns the public keys used to verify JWTs issued by this server +// @Tags OpenID Connect +// @Produce json +// @Success 200 {object} dto.JWKSResponse +// @Router /.well-known/jwks.json [get] +func (h *JWKSHandler) GetJWKS(c *gin.Context) { + pubKey := h.cfg.JWT.PublicKey + + if pubKey == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "public key not configured", "code": "INTERNAL_SERVER_ERROR"}) + return + } + + // The modulus N needs to be BigEndian bytes, encoded as Base64Url (without padding) + nBytes := pubKey.N.Bytes() + nStr := base64.RawURLEncoding.EncodeToString(nBytes) + + // Exponent E is an int, need to encode to bytes then Base64Url + eBytes := big.NewInt(int64(pubKey.E)).Bytes() + eStr := base64.RawURLEncoding.EncodeToString(eBytes) + + jwk := dto.JWK{ + Kty: "RSA", + Alg: "RS256", + Use: "sig", + Kid: h.cfg.JWT.KeyID, + N: nStr, + E: eStr, + } + + c.JSON(http.StatusOK, dto.JWKSResponse{ + Keys: []dto.JWK{jwk}, + }) +} diff --git a/internal/handler/oauth_handler_test.go b/internal/handler/oauth_handler_test.go index 9863331..3a29ec1 100644 --- a/internal/handler/oauth_handler_test.go +++ b/internal/handler/oauth_handler_test.go @@ -29,6 +29,7 @@ func setupOAuthUserInfoRouter(t *testing.T) (*gin.Engine, *repository.UserReposi userRepo := repository.NewUserRepository(db) tokenRepo := repository.NewOAuthTokenRepository(db) + priv, pub := testutils.GetTestRSAKeys(t) oauthProviderService := service.NewOAuthProviderService( repository.NewOAuthClientRepository(db), repository.NewAuthorizationCodeRepository(db), @@ -36,7 +37,7 @@ func setupOAuthUserInfoRouter(t *testing.T) (*gin.Engine, *repository.UserReposi repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), service.NewTokenService(&config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, }), &config.Config{}, ) @@ -76,6 +77,7 @@ func TestNewOAuthHandlerPanicsWithoutUserRepository(t *testing.T) { t.Cleanup(func() { mr.Close() }) tokenRepo := repository.NewOAuthTokenRepository(db) + priv, pub := testutils.GetTestRSAKeys(t) oauthProviderService := service.NewOAuthProviderService( repository.NewOAuthClientRepository(db), repository.NewAuthorizationCodeRepository(db), @@ -83,7 +85,7 @@ func TestNewOAuthHandlerPanicsWithoutUserRepository(t *testing.T) { repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), service.NewTokenService(&config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, }), &config.Config{}, ) @@ -358,6 +360,7 @@ func setupTokenRouter(t *testing.T) (*gin.Engine, *repository.OAuthClientReposit tokenRepo := repository.NewOAuthTokenRepository(db) userRepo := repository.NewUserRepository(db) + priv, pub := testutils.GetTestRSAKeys(t) oauthProviderService := service.NewOAuthProviderService( clientRepo, codeRepo, @@ -365,7 +368,7 @@ func setupTokenRouter(t *testing.T) (*gin.Engine, *repository.OAuthClientReposit repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), service.NewTokenService(&config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, }), &config.Config{}, ) diff --git a/internal/handler/oidc_discovery_handler.go b/internal/handler/oidc_discovery_handler.go new file mode 100644 index 0000000..e91694d --- /dev/null +++ b/internal/handler/oidc_discovery_handler.go @@ -0,0 +1,58 @@ +package handler + +import ( + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/roshankumar0036singh/auth-server/internal/config" + "github.com/roshankumar0036singh/auth-server/internal/dto" +) + +type OIDCDiscoveryHandler struct { + cfg *config.Config +} + +func NewOIDCDiscoveryHandler(cfg *config.Config) *OIDCDiscoveryHandler { + return &OIDCDiscoveryHandler{cfg: cfg} +} + +// GetConfiguration returns the OpenID Connect discovery document +// @Summary Get OpenID Connect Configuration +// @Description Returns the OpenID Connect discovery metadata document +// @Tags OpenID Connect +// @Produce json +// @Success 200 {object} dto.OIDCDiscoveryResponse +// @Router /.well-known/openid-configuration [get] +func (h *OIDCDiscoveryHandler) GetConfiguration(c *gin.Context) { + baseURL := strings.TrimRight(h.cfg.App.URL, "/") + + response := dto.OIDCDiscoveryResponse{ + Issuer: baseURL, + AuthorizationEndpoint: fmt.Sprintf("%s/oauth/authorize", baseURL), + TokenEndpoint: fmt.Sprintf("%s/oauth/token", baseURL), + UserinfoEndpoint: fmt.Sprintf("%s/oauth/userinfo", baseURL), + JwksURI: fmt.Sprintf("%s/.well-known/jwks.json", baseURL), + ScopesSupported: []string{ + "openid", + "profile", + "email", + }, + ResponseTypesSupported: []string{ + "code", + }, + GrantTypesSupported: []string{ + "authorization_code", + "refresh_token", + }, + SubjectTypesSupported: []string{ + "public", + }, + IDTokenSigningAlgValuesSupported: []string{ + "RS256", + }, + } + + c.JSON(http.StatusOK, response) +} diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go index cc8bff1..6012365 100644 --- a/internal/middleware/auth_test.go +++ b/internal/middleware/auth_test.go @@ -14,6 +14,7 @@ import ( "github.com/roshankumar0036singh/auth-server/internal/middleware" "github.com/roshankumar0036singh/auth-server/internal/models" "github.com/roshankumar0036singh/auth-server/internal/service" + "github.com/roshankumar0036singh/auth-server/internal/testutils" "github.com/stretchr/testify/assert" ) @@ -24,8 +25,9 @@ func setupTest(t *testing.T) (*gin.Engine, *service.TokenService, *service.Cache rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) t.Cleanup(func() { _ = rdb.Close() }) + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh", AccessExpiry: "15m"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key", AccessExpiry: "15m"}, } tokenService := service.NewTokenService(cfg) cacheService := service.NewCacheService(rdb) diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 4bb32b9..fd010cf 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -71,6 +71,8 @@ func SetupRoutes(router *gin.Engine, db *gorm.DB, redisClient *redis.Client, cfg adminHandler := handler.NewAdminHandler(authService) oauthClientHandler := handler.NewOAuthClientHandler(oauthProviderService) oauthHandler := handler.NewOAuthHandler(oauthProviderService, userRepo) + jwksHandler := handler.NewJWKSHandler(cfg) + discoveryHandler := handler.NewOIDCDiscoveryHandler(cfg) // Apply global middleware router.Use(middleware.CORSMiddleware(cfg)) @@ -121,6 +123,14 @@ func SetupRoutes(router *gin.Engine, db *gorm.DB, redisClient *redis.Client, cfg }) }) + // JWKS endpoint + router.GET("/.well-known/jwks.json", jwksHandler.GetJWKS) + + // OIDC Discovery endpoint + router.GET("/.well-known/openid-configuration", discoveryHandler.GetConfiguration) + + + // OAuth 2.0 Provider endpoints router.GET("/oauth/authorize", middleware.OptionalAuthMiddleware(tokenService, cacheService), oauthHandler.Authorize) router.POST("/oauth/authorize", middleware.AuthMiddleware(tokenService, cacheService), oauthHandler.AuthorizePost) diff --git a/internal/service/security_phase1_test.go b/internal/service/security_phase1_test.go index 9b51e80..fdcef10 100644 --- a/internal/service/security_phase1_test.go +++ b/internal/service/security_phase1_test.go @@ -18,9 +18,10 @@ import ( // testCfg mirrors the secrets used by testutils.SetupIntegrationTest so a // TokenService built here produces tokens the integration AuthService accepts. -func testCfg() *config.Config { +func testCfg(t *testing.T) *config.Config { + priv, pub := testutils.GetTestRSAKeys(t) return &config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, Security: config.SecurityConfig{RateLimitMax: 10, EncryptionKey: "12345678901234567890123456789012"}, App: config.AppConfig{URL: "http://localhost"}, } @@ -37,8 +38,8 @@ func newProviderService(t *testing.T) (*service.OAuthProviderService, *repositor tokenRepo, repository.NewUserConsentRepository(db), repository.NewOAuthProviderConfigRepository(db), - service.NewTokenService(testCfg()), - testCfg(), + service.NewTokenService(testCfg(t)), + testCfg(t), ) return ps, tokenRepo } diff --git a/internal/service/token_service.go b/internal/service/token_service.go index d66a349..1a7021e 100644 --- a/internal/service/token_service.go +++ b/internal/service/token_service.go @@ -42,6 +42,31 @@ const ( errInvalidToken = "invalid token" ) +func (s *TokenService) signToken(claims *JWTClaims) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = s.cfg.JWT.KeyID + return token.SignedString(s.cfg.JWT.PrivateKey) +} + +func (s *TokenService) parseToken(tokenString string) (*JWTClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, errors.New(errInvalidSignMethod) + } + return s.cfg.JWT.PublicKey, nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { + return claims, nil + } + + return nil, errors.New(errInvalidToken) +} + // GenerateAccessToken generates a new JWT access token func (s *TokenService) GenerateAccessToken(user *models.User, sessionID string) (string, error) { expirationTime := time.Now().Add(15 * time.Minute) // 15 minutes @@ -59,13 +84,7 @@ func (s *TokenService) GenerateAccessToken(user *models.User, sessionID string) }, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(s.cfg.JWT.AccessSecret)) - if err != nil { - return "", err - } - - return tokenString, nil + return s.signToken(claims) } // GenerateRefreshToken generates a new refresh token (longer expiry) @@ -84,38 +103,23 @@ func (s *TokenService) GenerateRefreshToken(user *models.User) (string, error) { }, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(s.cfg.JWT.RefreshSecret)) - if err != nil { - return "", err - } - - return tokenString, nil + return s.signToken(claims) } // ValidateAccessToken validates and parses an access token func (s *TokenService) ValidateAccessToken(tokenString string) (*JWTClaims, error) { - token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.New(errInvalidSignMethod) - } - return []byte(s.cfg.JWT.AccessSecret), nil - }) - + claims, err := s.parseToken(tokenString) if err != nil { return nil, err } - - if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { - // Purpose-scoped tokens (e.g. the MFA-pending token) must never be - // accepted as access tokens. - if claims.Purpose != "" { - return nil, errors.New(errInvalidToken) - } - return claims, nil + + // Purpose-scoped tokens (e.g. the MFA-pending token) must never be + // accepted as access tokens. + if claims.Purpose != "" { + return nil, errors.New(errInvalidToken) } - - return nil, errors.New(errInvalidToken) + + return claims, nil } // GenerateMFAToken issues a short-lived token proving the password step of @@ -133,47 +137,26 @@ func (s *TokenService) GenerateMFAToken(userID string) (string, error) { }, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(s.cfg.JWT.AccessSecret)) + return s.signToken(claims) } // ValidateMFAToken validates an MFA-pending token and returns the user ID it // was issued for. It rejects any token whose Purpose is not the MFA-pending // marker, so access/refresh tokens cannot stand in for it. func (s *TokenService) ValidateMFAToken(tokenString string) (string, error) { - token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.New(errInvalidSignMethod) - } - return []byte(s.cfg.JWT.AccessSecret), nil - }) + claims, err := s.parseToken(tokenString) if err != nil { return "", err } - - claims, ok := token.Claims.(*JWTClaims) - if !ok || !token.Valid || claims.Purpose != mfaPendingPurpose { + + if claims.Purpose != mfaPendingPurpose { return "", errors.New("invalid mfa token") } + return claims.UserID, nil } // ValidateRefreshToken validates and parses a refresh token func (s *TokenService) ValidateRefreshToken(tokenString string) (*JWTClaims, error) { - token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.New(errInvalidSignMethod) - } - return []byte(s.cfg.JWT.RefreshSecret), nil - }) - - if err != nil { - return nil, err - } - - if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { - return claims, nil - } - - return nil, errors.New(errInvalidToken) + return s.parseToken(tokenString) } diff --git a/internal/service/token_service_test.go b/internal/service/token_service_test.go index 1a085b0..7170557 100644 --- a/internal/service/token_service_test.go +++ b/internal/service/token_service_test.go @@ -6,14 +6,17 @@ import ( "github.com/roshankumar0036singh/auth-server/internal/config" "github.com/roshankumar0036singh/auth-server/internal/models" "github.com/roshankumar0036singh/auth-server/internal/service" + "github.com/roshankumar0036singh/auth-server/internal/testutils" "github.com/stretchr/testify/assert" ) func TestTokenService_GenerateAccessToken(t *testing.T) { + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ JWT: config.JWTConfig{ - AccessSecret: "test-secret", - RefreshSecret: "test-refresh-secret", + PrivateKey: priv, + PublicKey: pub, + KeyID: "test-key", }, } svc := service.NewTokenService(cfg) @@ -39,10 +42,12 @@ func TestTokenService_GenerateAccessToken(t *testing.T) { } func TestTokenService_GenerateRefreshToken(t *testing.T) { + priv, pub := testutils.GetTestRSAKeys(t) cfg := &config.Config{ JWT: config.JWTConfig{ - AccessSecret: "test-secret", - RefreshSecret: "test-refresh-secret", + PrivateKey: priv, + PublicKey: pub, + KeyID: "test-key", }, } svc := service.NewTokenService(cfg) diff --git a/internal/testutils/setup.go b/internal/testutils/setup.go index bfffd4d..714fba0 100644 --- a/internal/testutils/setup.go +++ b/internal/testutils/setup.go @@ -1,6 +1,9 @@ package testutils import ( + "crypto/rand" + "crypto/rsa" + "sync" "testing" "github.com/alicebob/miniredis/v2" @@ -14,6 +17,24 @@ import ( "gorm.io/gorm" ) +var ( + testPrivKey *rsa.PrivateKey + testPubKey *rsa.PublicKey + testKeyOnce sync.Once +) + +func GetTestRSAKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { + testKeyOnce.Do(func() { + var err error + testPrivKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate test RSA key: %v", err) + } + testPubKey = &testPrivKey.PublicKey + }) + return testPrivKey, testPubKey +} + // MockEmailSender type MockEmailSender struct { LastEmail map[string]string @@ -109,8 +130,9 @@ func SetupIntegrationTest(t *testing.T) (*service.AuthService, *gorm.DB, *minire auditRepo := repository.NewAuditRepository(db) // 4. Services + priv, pub := GetTestRSAKeys(t) cfg := &config.Config{ - JWT: config.JWTConfig{AccessSecret: "secret", RefreshSecret: "refresh"}, + JWT: config.JWTConfig{PrivateKey: priv, PublicKey: pub, KeyID: "test-key"}, Security: config.SecurityConfig{RateLimitMax: 10, RateLimitWindow: 60}, App: config.AppConfig{URL: "http://localhost"}, }