diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9061a69..b6c7e88 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,7 +48,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.22" + go-version: "1.25" cache-dependency-path: services/api/go.sum - name: go vet diff --git a/services/api/go.mod b/services/api/go.mod index 1521f16..7b57f41 100644 --- a/services/api/go.mod +++ b/services/api/go.mod @@ -1,3 +1,5 @@ module github.com/Depo-dev/trident/services/api -go 1.22 +go 1.25.0 + +require golang.org/x/time v0.15.0 diff --git a/services/api/go.sum b/services/api/go.sum index e69de29..8ecb2d8 100644 --- a/services/api/go.sum +++ b/services/api/go.sum @@ -0,0 +1,2 @@ +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= diff --git a/services/api/middleware/auth.go b/services/api/middleware/auth.go new file mode 100644 index 0000000..91b71dc --- /dev/null +++ b/services/api/middleware/auth.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "net/http" + "os" + "strings" +) + +// APIKey returns an HTTP middleware that validates the X-API-Key header on all +// /v1/* and /ws routes. GET /v1/health is exempt. +// +// The incoming key is HMAC-SHA256'd with API_KEY_SALT and compared against the +// comma-separated list of pre-hashed keys in API_KEY_HASHES. Returns 401 if +// the header is missing or the key is unrecognised. +func APIKey(next http.Handler) http.Handler { + salt := []byte(os.Getenv("API_KEY_SALT")) + rawHashes := os.Getenv("API_KEY_HASHES") + + validHashes := make(map[string]struct{}) + for _, h := range strings.Split(rawHashes, ",") { + h = strings.TrimSpace(h) + if h != "" { + validHashes[h] = struct{}{} + } + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Health check is always public. + if r.Method == http.MethodGet && r.URL.Path == "/v1/health" { + next.ServeHTTP(w, r) + return + } + + // Only guard /v1/* and /ws paths. + if !strings.HasPrefix(r.URL.Path, "/v1/") && r.URL.Path != "/ws" { + next.ServeHTTP(w, r) + return + } + + key := r.Header.Get("X-API-Key") + if key == "" { + http.Error(w, "missing X-API-Key header", http.StatusUnauthorized) + return + } + + hashed := hmacSHA256Hex(salt, key) + if _, ok := validHashes[hashed]; !ok { + http.Error(w, "invalid API key", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +func hmacSHA256Hex(salt []byte, key string) string { + mac := hmac.New(sha256.New, salt) + mac.Write([]byte(key)) + return hex.EncodeToString(mac.Sum(nil)) +} diff --git a/services/api/middleware/auth_test.go b/services/api/middleware/auth_test.go new file mode 100644 index 0000000..66c7d38 --- /dev/null +++ b/services/api/middleware/auth_test.go @@ -0,0 +1,91 @@ +package middleware_test + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Depo-dev/trident/services/api/middleware" +) + +func hashKey(salt, key string) string { + mac := hmac.New(sha256.New, []byte(salt)) + mac.Write([]byte(key)) + return hex.EncodeToString(mac.Sum(nil)) +} + +func okHandler(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func TestAPIKey_validKeyPasses(t *testing.T) { + const salt = "testsalt" + const key = "my-secret-key" + + t.Setenv("API_KEY_SALT", salt) + t.Setenv("API_KEY_HASHES", hashKey(salt, key)) + + handler := middleware.APIKey(http.HandlerFunc(okHandler)) + + req := httptest.NewRequest(http.MethodGet, "/v1/events", nil) + req.Header.Set("X-API-Key", key) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } +} + +func TestAPIKey_missingHeader_returns401(t *testing.T) { + t.Setenv("API_KEY_SALT", "testsalt") + t.Setenv("API_KEY_HASHES", "somehash") + + handler := middleware.APIKey(http.HandlerFunc(okHandler)) + + req := httptest.NewRequest(http.MethodGet, "/v1/events", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rec.Code) + } +} + +func TestAPIKey_invalidKey_returns401(t *testing.T) { + t.Setenv("API_KEY_SALT", "testsalt") + t.Setenv("API_KEY_HASHES", hashKey("testsalt", "correct-key")) + + handler := middleware.APIKey(http.HandlerFunc(okHandler)) + + req := httptest.NewRequest(http.MethodGet, "/v1/events", nil) + req.Header.Set("X-API-Key", "wrong-key") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rec.Code) + } +} + +func TestAPIKey_healthSkipsAuth(t *testing.T) { + t.Setenv("API_KEY_SALT", "testsalt") + t.Setenv("API_KEY_HASHES", "") // no valid keys + + handler := middleware.APIKey(http.HandlerFunc(okHandler)) + + req := httptest.NewRequest(http.MethodGet, "/v1/health", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 on /v1/health without key, got %d", rec.Code) + } +} diff --git a/services/api/middleware/ratelimit.go b/services/api/middleware/ratelimit.go new file mode 100644 index 0000000..55888f5 --- /dev/null +++ b/services/api/middleware/ratelimit.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "net/http" + "os" + "strconv" + "sync" + + "golang.org/x/time/rate" +) + +type rateLimitStore struct { + mu sync.Mutex + limiters map[string]*rate.Limiter + rps rate.Limit + burst int +} + +func newRateLimitStore(rps rate.Limit, burst int) *rateLimitStore { + return &rateLimitStore{ + limiters: make(map[string]*rate.Limiter), + rps: rps, + burst: burst, + } +} + +func (s *rateLimitStore) get(key string) *rate.Limiter { + s.mu.Lock() + defer s.mu.Unlock() + + if lim, ok := s.limiters[key]; ok { + return lim + } + + lim := rate.NewLimiter(s.rps, s.burst) + s.limiters[key] = lim + return lim +} + +// RateLimit returns an HTTP middleware that enforces a per-API-key token bucket +// limit. Exceeding the limit returns 429 with a Retry-After: 1 header. +// +// Default limits (100 req/s, burst 200) are overridden by RATE_LIMIT_RPS and +// RATE_LIMIT_BURST env vars. +func RateLimit(next http.Handler) http.Handler { + rps := rate.Limit(100) + burst := 200 + + if v := os.Getenv("RATE_LIMIT_RPS"); v != "" { + if n, err := strconv.ParseFloat(v, 64); err == nil { + rps = rate.Limit(n) + } + } + if v := os.Getenv("RATE_LIMIT_BURST"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + burst = n + } + } + + store := newRateLimitStore(rps, burst) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := r.Header.Get("X-API-Key") + if key == "" { + // No key: the auth middleware will handle this; pass through here. + next.ServeHTTP(w, r) + return + } + + if !store.get(key).Allow() { + w.Header().Set("Retry-After", "1") + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/services/api/middleware/ratelimit_test.go b/services/api/middleware/ratelimit_test.go new file mode 100644 index 0000000..2ceaaa1 --- /dev/null +++ b/services/api/middleware/ratelimit_test.go @@ -0,0 +1,41 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Depo-dev/trident/services/api/middleware" +) + +func TestRateLimit_exceedBurst_returns429(t *testing.T) { + t.Setenv("RATE_LIMIT_RPS", "1") + t.Setenv("RATE_LIMIT_BURST", "2") + + handler := middleware.RateLimit(http.HandlerFunc(okHandler)) + + // First two requests should succeed (burst=2). + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/v1/events", nil) + req.Header.Set("X-API-Key", "test-key") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) + } + } + + // Third request exceeds the burst and must be rejected. + req := httptest.NewRequest(http.MethodGet, "/v1/events", nil) + req.Header.Set("X-API-Key", "test-key") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429 after burst exceeded, got %d", rec.Code) + } + + if rec.Header().Get("Retry-After") == "" { + t.Error("expected Retry-After header on 429 response") + } +}