diff --git a/cmd/iron-proxy/main.go b/cmd/iron-proxy/main.go index 238b623..c8c756f 100644 --- a/cmd/iron-proxy/main.go +++ b/cmd/iron-proxy/main.go @@ -83,15 +83,14 @@ func main() { // Managed mode is determined by the presence of a control plane token. managed := proxyToken != "" - if cfg.Management.Listen != "" { - if managed { - fmt.Fprintln(os.Stderr, "error: management.listen cannot be used with managed mode; the control plane is the source of truth") - os.Exit(1) - } - if *configPath == "" { - fmt.Fprintln(os.Stderr, "error: management.listen requires --config; /v1/reload has no file to re-read") - os.Exit(1) - } + // Standalone mode serves /v1/reload, which re-reads the config file. + // Managed mode serves /v1/status and /v1/sync instead: the control plane + // stays the source of truth for config, while the sandbox control plane + // can verify which principal's config the proxy has actually applied + // before routing traffic through it. + if cfg.Management.Listen != "" && !managed && *configPath == "" { + fmt.Fprintln(os.Stderr, "error: management.listen requires --config in standalone mode; /v1/reload has no file to re-read") + os.Exit(1) } // Both modes produce a pipeline holder. Managed mode populates the @@ -121,10 +120,11 @@ func main() { var holder *transform.PipelineHolder var mcpHolder *mcp.PolicyHolder var otelCfg iotel.ExportConfig + var poller *controlplane.Poller if managed { var ingestToken string - holder, mcpHolder, ingestToken, pgListener = initManaged(ctx, cfg, bodyLimits, errc, proxyToken, pgManager, localPgListener, logger) + holder, mcpHolder, ingestToken, pgListener, poller = initManaged(ctx, cfg, bodyLimits, errc, proxyToken, pgManager, localPgListener, logger) if ingestToken != "" { otelCfg.DefaultEndpoint = "https://ingest.iron.sh/v1/logs" otelCfg.DefaultHeaders = map[string]string{ @@ -222,21 +222,32 @@ func main() { Logger: logger, UpstreamResponseHeaderTimeout: time.Duration(cfg.Proxy.UpstreamResponseHeaderTimeout), UpstreamProxy: cfg.Proxy.UpstreamProxy.ProxyFunc(), + // Managed proxies fail closed until the first control-plane config + // has been applied; an un-synced pipeline would otherwise pass + // requests through with placeholder credentials intact. + Ready: managedReady(poller), }) // Initialize metrics server. metricsServer := metrics.New(cfg.Metrics.Listen, logger) - // Initialize management server (standalone mode only; guarded above). + // Initialize management server: /v1/reload in standalone mode, + // /v1/status and /v1/sync in managed mode. var mgmtServer *management.Server if cfg.Management.Listen != "" { - mgmtServer = management.New(management.Options{ + mgmtOpts := management.Options{ Addr: cfg.Management.Listen, APIKey: os.Getenv(cfg.Management.APIKeyEnv), - Reload: newReloadFunc(*configPath, holder, mcpHolder, pgManager, bodyLimits, logger), Logger: logger, Ctx: ctx, - }) + } + if managed { + mgmtOpts.Status = func() any { return poller.Status() } + mgmtOpts.SyncNow = poller.Poke + } else { + mgmtOpts.Reload = newReloadFunc(*configPath, holder, mcpHolder, pgManager, bodyLimits, logger) + } + mgmtServer = management.New(mgmtOpts) } // Start services. @@ -321,7 +332,7 @@ func main() { // // Initial MCP policy preference: control-plane-supplied mcp block first, then // fall back to cfg.MCP from the YAML if the sync did not include one. -func initManaged(ctx context.Context, cfg *config.Config, bodyLimits transform.BodyLimits, errc chan<- error, proxyToken string, pgManager *postgres.Manager, localPgListener *postgres.Listener, logger *slog.Logger) (*transform.PipelineHolder, *mcp.PolicyHolder, string, *postgres.Listener) { +func initManaged(ctx context.Context, cfg *config.Config, bodyLimits transform.BodyLimits, errc chan<- error, proxyToken string, pgManager *postgres.Manager, localPgListener *postgres.Listener, logger *slog.Logger) (*transform.PipelineHolder, *mcp.PolicyHolder, string, *postgres.Listener, *controlplane.Poller) { cpURL := envOrDefault("IRON_CONTROL_PLANE_URL", "https://api.iron.sh") logger.Info("starting in managed mode", slog.String("control_plane_url", cpURL)) @@ -391,22 +402,41 @@ func initManaged(ctx context.Context, cfg *config.Config, bodyLimits transform.B // Start config poller. poller := controlplane.NewPoller(client, configHash, func(u controlplane.SyncUpdate) error { if u.Rules != nil || u.Secrets != nil || u.Transforms != nil { - applyPipelineSync(holder, bodyLimits, logger, u.Rules, u.Secrets, u.Transforms) + if err := applyPipelineSync(holder, bodyLimits, logger, u.Rules, u.Secrets, u.Transforms); err != nil { + return err + } } if u.MCP != nil { - applyMCPSync(mcpHolder, logger, u.MCP) + if err := applyMCPSync(mcpHolder, logger, u.MCP); err != nil { + return err + } } if u.Postgres != nil { - applyPostgresSync(ctx, pgManager, localPgListener, os.Getenv, logger, u.Postgres) + if err := applyPostgresSync(ctx, pgManager, localPgListener, os.Getenv, logger, u.Postgres); err != nil { + return err + } } return nil }, logger) + // Seed the poller's status from the startup sync so /v1/status (and the + // fail-closed gate) reflect it before the polling loop's first pass. + poller.SeedStatus(syncResp) + go func() { errc <- poller.Run(ctx) }() - return holder, mcpHolder, ingestToken, pgListener + return holder, mcpHolder, ingestToken, pgListener, poller +} + +// managedReady gates the proxy on the first applied control-plane config. +// A nil poller (standalone mode) means always ready. +func managedReady(poller *controlplane.Poller) func() bool { + if poller == nil { + return nil + } + return func() bool { return poller.Status().SyncedOnce } } // buildInitialMCPHolder picks the initial MCP policy source: a control-plane @@ -452,20 +482,21 @@ func initStandalone(cfg *config.Config, bodyLimits transform.BodyLimits, logger // swaps it in. If parsing or pipeline construction fails, the existing pipeline // is preserved and an error is logged: an invalid push from the control plane // must not take down the proxy. -func applyPipelineSync(holder *transform.PipelineHolder, bodyLimits transform.BodyLimits, logger *slog.Logger, rules, secrets, transforms json.RawMessage) { +func applyPipelineSync(holder *transform.PipelineHolder, bodyLimits transform.BodyLimits, logger *slog.Logger, rules, secrets, transforms json.RawMessage) error { newTransforms, err := config.TransformsFromSync(rules, secrets, transforms) if err != nil { logger.Error("rejecting invalid pipeline config from sync, keeping current pipeline", slog.String("error", err.Error())) - return + return fmt.Errorf("pipeline sync: %w", err) } newPipeline, err := buildPipeline(newTransforms, bodyLimits, logger) if err != nil { logger.Error("rejecting invalid pipeline config from sync, keeping current pipeline", slog.String("error", err.Error())) - return + return fmt.Errorf("pipeline sync: %w", err) } newPipeline.SetAuditFunc(holder.Load().AuditFunc()) holder.Store(newPipeline) logger.Info("pipeline reloaded", slog.String("transforms", newPipeline.Names())) + return nil } // applyMCPSync compiles a new MCP policy from a sync payload and atomically @@ -473,20 +504,20 @@ func applyPipelineSync(holder *transform.PipelineHolder, bodyLimits transform.Bo // preserved: an invalid push from the control plane must not take down a // running proxy. An empty/null mcp block is interpreted by the caller as // "no update" and is not delivered here. -func applyMCPSync(holder *mcp.PolicyHolder, logger *slog.Logger, raw json.RawMessage) { +func applyMCPSync(holder *mcp.PolicyHolder, logger *slog.Logger, raw json.RawMessage) error { node, present, err := config.MCPFromSync(raw) if err != nil { logger.Error("rejecting invalid mcp policy from sync, keeping current policy", slog.String("error", err.Error())) - return + return fmt.Errorf("mcp sync: %w", err) } if !present { // Should not happen — caller filters absent/null — but treat as no-op. - return + return nil } policy, err := mcp.LoadFromNode(node) if err != nil { logger.Error("rejecting invalid mcp policy from sync, keeping current policy", slog.String("error", err.Error())) - return + return fmt.Errorf("mcp sync: %w", err) } holder.Store(policy) if policy == nil { @@ -494,6 +525,7 @@ func applyMCPSync(holder *mcp.PolicyHolder, logger *slog.Logger, raw json.RawMes } else { logger.Info("mcp policy reloaded") } + return nil } // Environment variables that configure the managed postgres listener when the @@ -585,13 +617,14 @@ func postgresListenerFromSync(local *postgres.Listener, getenv func(string) stri // applyPostgresSync rebuilds the postgres listener from a sync payload and // hot-reloads the manager. An invalid payload is logged and the running // listener is preserved. -func applyPostgresSync(ctx context.Context, mgr *postgres.Manager, local *postgres.Listener, getenv func(string) string, logger *slog.Logger, raw json.RawMessage) { +func applyPostgresSync(ctx context.Context, mgr *postgres.Manager, local *postgres.Listener, getenv func(string) string, logger *slog.Logger, raw json.RawMessage) error { listener, ok := postgresListenerFromSync(local, getenv, logger, raw) if !ok { - return + return fmt.Errorf("postgres sync: invalid postgres config") } mgr.Reload(ctx, listener) logger.Info("postgres listener reloaded from sync", slog.Bool("running", listener != nil)) + return nil } // newReloadFunc returns a management.ReloadFunc that re-reads the YAML config diff --git a/cmd/iron-proxy/main_test.go b/cmd/iron-proxy/main_test.go index 0a0cb69..2795e6b 100644 --- a/cmd/iron-proxy/main_test.go +++ b/cmd/iron-proxy/main_test.go @@ -173,7 +173,7 @@ func TestApplyPostgresSync_ReloadsListener(t *testing.T) { raw := json.RawMessage(`[{"id":"pgs_1","foreign_id":"pg-analytics","database":"analytics","dsn":{"type":"env","var":"PG_DSN"}}]`) - applyPostgresSync(context.Background(), mgr, nil, mapEnv(pgListenerEnv()), discardLogger(), raw) + require.NoError(t, applyPostgresSync(context.Background(), mgr, nil, mapEnv(pgListenerEnv()), discardLogger(), raw)) require.True(t, mgr.Running()) } @@ -187,7 +187,7 @@ func TestApplyPipelineSync_ValidConfig_Swaps(t *testing.T) { logger := slog.New(slog.NewTextHandler(logBuf, nil)) rules := json.RawMessage(`[{"host":"example.com","methods":["GET"],"paths":["/api/*"]}]`) - applyPipelineSync(holder, transform.BodyLimits{}, logger, rules, nil, nil) + require.NoError(t, applyPipelineSync(holder, transform.BodyLimits{}, logger, rules, nil, nil)) require.NotSame(t, original, holder.Load(), "pipeline should have been swapped") require.Equal(t, "allowlist", holder.Load().Names()) @@ -201,7 +201,7 @@ func TestApplyPipelineSync_InvalidJSON_KeepsExistingPipeline(t *testing.T) { logBuf := &bytes.Buffer{} logger := slog.New(slog.NewTextHandler(logBuf, nil)) - applyPipelineSync(holder, transform.BodyLimits{}, logger, json.RawMessage(`{not json`), nil, nil) + require.Error(t, applyPipelineSync(holder, transform.BodyLimits{}, logger, json.RawMessage(`{not json`), nil, nil)) require.Same(t, original, holder.Load(), "pipeline must not be swapped on invalid config") require.Contains(t, logBuf.String(), "rejecting invalid pipeline config") @@ -217,7 +217,7 @@ func TestApplyPipelineSync_InvalidRule_KeepsExistingPipeline(t *testing.T) { // host and cidr are mutually exclusive — rule construction fails. rules := json.RawMessage(`[{"host":"example.com","cidr":"10.0.0.0/8"}]`) - applyPipelineSync(holder, transform.BodyLimits{}, logger, rules, nil, nil) + require.Error(t, applyPipelineSync(holder, transform.BodyLimits{}, logger, rules, nil, nil)) require.Same(t, original, holder.Load(), "pipeline must not be swapped when transform construction fails") require.Contains(t, logBuf.String(), "rejecting invalid pipeline config") @@ -233,7 +233,7 @@ func TestApplyPipelineSync_PreservesAuditFunc(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) rules := json.RawMessage(`[{"host":"example.com"}]`) - applyPipelineSync(holder, transform.BodyLimits{}, logger, rules, nil, nil) + require.NoError(t, applyPipelineSync(holder, transform.BodyLimits{}, logger, rules, nil, nil)) holder.Load().EmitAudit(nil) require.True(t, called, "audit func should be carried over to the new pipeline") diff --git a/internal/config/env.go b/internal/config/env.go index de6e832..8458f80 100644 --- a/internal/config/env.go +++ b/internal/config/env.go @@ -47,6 +47,9 @@ func applyEnvOverrides(cfg *Config) error { if v := os.Getenv("IRON_METRICS_LISTEN"); v != "" { cfg.Metrics.Listen = v } + if v := os.Getenv("IRON_MANAGEMENT_LISTEN"); v != "" { + cfg.Management.Listen = v + } if v := os.Getenv("IRON_LOG_LEVEL"); v != "" { cfg.Log.Level = v } diff --git a/internal/controlplane/client.go b/internal/controlplane/client.go index b2227f4..9f1a17f 100644 --- a/internal/controlplane/client.go +++ b/internal/controlplane/client.go @@ -23,6 +23,11 @@ type SyncResponse struct { MCP json.RawMessage `json:"mcp"` Postgres json.RawMessage `json:"postgres"` IngestToken string `json:"ingest_token"` + // Status and PrincipalID describe the proxy's control-plane assignment as + // of this sync. The control plane includes them only on responses that + // carry a config payload; hash-match responses leave them empty. + Status string `json:"status"` + PrincipalID string `json:"principal_id"` } // Client talks to the iron.sh control plane REST API. Requests are diff --git a/internal/controlplane/poller.go b/internal/controlplane/poller.go index cdb5e4d..dc0a476 100644 --- a/internal/controlplane/poller.go +++ b/internal/controlplane/poller.go @@ -6,6 +6,7 @@ import ( "errors" "log/slog" "math/rand/v2" + "sync" "time" ) @@ -23,12 +24,28 @@ type SyncUpdate struct { Postgres json.RawMessage } +// Status is a snapshot of the poller's applied control-plane state. The +// management API serves it so an operator (or the sandbox control plane) +// can verify which principal's config this proxy is actually enforcing +// before routing traffic through it. +type Status struct { + ConfigHash string `json:"config_hash"` + PrincipalID string `json:"principal_id"` + PrincipalStatus string `json:"principal_status"` + SyncedOnce bool `json:"synced_once"` + LastSyncAt time.Time `json:"last_sync_at"` +} + // Poller periodically calls Sync and applies config updates. type Poller struct { client *Client configHash string onUpdate func(SyncUpdate) error logger *slog.Logger + + mu sync.RWMutex + status Status + poke chan struct{} } // NewPoller creates a new sync poller. @@ -38,12 +55,56 @@ func NewPoller(client *Client, initialConfigHash string, onUpdate func(SyncUpdat configHash: initialConfigHash, onUpdate: onUpdate, logger: logger, + poke: make(chan struct{}, 1), + } +} + +// Poke requests an immediate out-of-band sync. It never blocks: at most one +// poke is queued, and a poke arriving while a sync is in flight coalesces +// into the next loop iteration. +func (p *Poller) Poke() { + select { + case p.poke <- struct{}{}: + default: + } +} + +// Status returns a snapshot of the applied control-plane state. +func (p *Poller) Status() Status { + p.mu.RLock() + defer p.mu.RUnlock() + return p.status +} + +// SeedStatus records the result of a sync performed outside the poller (the +// startup sync in managed mode) so Status reflects it before Run's first +// iteration. +func (p *Poller) SeedStatus(resp *SyncResponse) { + if resp == nil { + return + } + p.recordSync(resp) +} + +func (p *Poller) recordSync(resp *SyncResponse) { + p.mu.Lock() + defer p.mu.Unlock() + p.status.ConfigHash = resp.ConfigHash + p.status.SyncedOnce = true + p.status.LastSyncAt = time.Now().UTC() + // Hash-match responses omit the assignment fields; keep the last known + // values so Status stays meaningful between config changes. + if resp.Status != "" { + p.status.PrincipalStatus = resp.Status + } + if resp.PrincipalID != "" { + p.status.PrincipalID = resp.PrincipalID } } // Run starts the polling loop. It performs an initial sync immediately, then -// polls on PollInterval with ±10% jitter. Returns when ctx is canceled or -// a revocation error is received. +// polls on PollInterval with ±10% jitter; a Poke wakes it early. Returns when +// ctx is canceled or a revocation error is received. func (p *Poller) Run(ctx context.Context) error { if err := p.sync(ctx); err != nil { if isRevocationError(err) { @@ -61,6 +122,8 @@ func (p *Poller) Run(ctx context.Context) error { timer.Stop() return nil case <-timer.C: + case <-p.poke: + timer.Stop() } if err := p.sync(ctx); err != nil { @@ -103,11 +166,13 @@ func (p *Poller) sync(ctx context.Context) error { Postgres: resp.Postgres, }); err != nil { p.logger.Error("applying config update", slog.String("error", err.Error())) + return err } } } p.configHash = resp.ConfigHash + p.recordSync(resp) return nil } diff --git a/internal/controlplane/poller_test.go b/internal/controlplane/poller_test.go index d3804d5..ca982c3 100644 --- a/internal/controlplane/poller_test.go +++ b/internal/controlplane/poller_test.go @@ -3,6 +3,7 @@ package controlplane import ( "context" "encoding/json" + "errors" "io" "log/slog" "net/http" @@ -207,3 +208,159 @@ func TestPollerGracefulShutdown(t *testing.T) { t.Fatal("poller did not shut down in time") } } + +func TestPollerPokeTriggersImmediateSync(t *testing.T) { + var syncCalls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := syncCalls.Add(1) + resp := SyncResponse{ConfigHash: "sha256:one"} + if n > 1 { + resp = SyncResponse{ + ConfigHash: "sha256:two", + Status: "assigned", + PrincipalID: "prn_session", + Secrets: json.RawMessage(`[]`), + } + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewClient(server.URL, "irpt_test", testLogger()) + poller := NewPoller(client, "", nil, testLogger()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- poller.Run(ctx) }() + + // The initial sync runs immediately; wait for it. + require.Eventually(t, func() bool { + return poller.Status().SyncedOnce + }, 2*time.Second, 10*time.Millisecond) + require.Equal(t, "sha256:one", poller.Status().ConfigHash) + + // A poke must trigger the second sync long before the 5s poll interval. + poller.Poke() + require.Eventually(t, func() bool { + return poller.Status().ConfigHash == "sha256:two" + }, 2*time.Second, 10*time.Millisecond) + + status := poller.Status() + require.Equal(t, "prn_session", status.PrincipalID) + require.Equal(t, "assigned", status.PrincipalStatus) + require.True(t, status.SyncedOnce) + require.False(t, status.LastSyncAt.IsZero()) + + cancel() + require.NoError(t, <-done) +} + +func TestPollerStatusRetainsPrincipalOnHashMatch(t *testing.T) { + var syncCalls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := syncCalls.Add(1) + resp := SyncResponse{ConfigHash: "sha256:same"} + if n == 1 { + resp.Status = "assigned" + resp.PrincipalID = "prn_keep" + resp.Secrets = json.RawMessage(`[]`) + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewClient(server.URL, "irpt_test", testLogger()) + poller := NewPoller(client, "", nil, testLogger()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- poller.Run(ctx) }() + + require.Eventually(t, func() bool { + return poller.Status().PrincipalID == "prn_keep" + }, 2*time.Second, 10*time.Millisecond) + + // Hash-match responses omit the assignment fields; they must be retained. + poller.Poke() + require.Eventually(t, func() bool { + return syncCalls.Load() >= 2 + }, 2*time.Second, 10*time.Millisecond) + require.Equal(t, "prn_keep", poller.Status().PrincipalID) + require.Equal(t, "assigned", poller.Status().PrincipalStatus) + + cancel() + require.NoError(t, <-done) +} + +func TestPollerDoesNotAdvanceStatusWhenApplyFails(t *testing.T) { + var syncCalls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := syncCalls.Add(1) + resp := SyncResponse{ + ConfigHash: "sha256:one", + Status: "assigned", + PrincipalID: "prn_one", + Secrets: json.RawMessage(`[]`), + } + if n > 1 { + resp.ConfigHash = "sha256:two" + resp.PrincipalID = "prn_two" + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewClient(server.URL, "irpt_test", testLogger()) + var updates atomic.Int32 + poller := NewPoller(client, "", func(SyncUpdate) error { + if updates.Add(1) > 1 { + return errors.New("apply failed") + } + return nil + }, testLogger()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- poller.Run(ctx) }() + + require.Eventually(t, func() bool { + return poller.Status().PrincipalID == "prn_one" + }, 2*time.Second, 10*time.Millisecond) + + poller.Poke() + require.Eventually(t, func() bool { + return syncCalls.Load() >= 2 + }, 2*time.Second, 10*time.Millisecond) + + status := poller.Status() + require.Equal(t, "sha256:one", status.ConfigHash) + require.Equal(t, "prn_one", status.PrincipalID) + + cancel() + require.NoError(t, <-done) +} + +func TestPollerSeedStatus(t *testing.T) { + client := NewClient("http://127.0.0.1:0", "irpt_test", testLogger()) + poller := NewPoller(client, "", nil, testLogger()) + require.False(t, poller.Status().SyncedOnce) + + poller.SeedStatus(nil) + require.False(t, poller.Status().SyncedOnce) + + poller.SeedStatus(&SyncResponse{ + ConfigHash: "sha256:seed", + Status: "assigned", + PrincipalID: "prn_boot", + }) + status := poller.Status() + require.True(t, status.SyncedOnce) + require.Equal(t, "sha256:seed", status.ConfigHash) + require.Equal(t, "prn_boot", status.PrincipalID) +} diff --git a/internal/management/server.go b/internal/management/server.go index 167bd1f..ac02fd0 100644 --- a/internal/management/server.go +++ b/internal/management/server.go @@ -45,8 +45,16 @@ type statusResponse struct { type Options struct { Addr string APIKey string + // Reload rebuilds the pipeline from the on-disk config. Standalone mode + // only; nil disables /v1/reload (managed proxies have no file to re-read). Reload ReloadFunc - Logger *slog.Logger + // Status returns a JSON-encodable snapshot of the applied control-plane + // state. Managed mode only; nil disables /v1/status. + Status func() any + // SyncNow requests an immediate control-plane sync. Managed mode only; + // nil disables /v1/sync. + SyncNow func() + Logger *slog.Logger // Ctx is the process-scoped context passed to Reload. It must outlive // individual HTTP requests so a client disconnect cannot abort a reload @@ -57,11 +65,13 @@ type Options struct { // Server serves the management API. type Server struct { - server *http.Server - apiKey string - reload ReloadFunc - logger *slog.Logger - ctx context.Context + server *http.Server + apiKey string + reload ReloadFunc + status func() any + syncNow func() + logger *slog.Logger + ctx context.Context } // New creates a Server bound to opts.Addr. The caller starts it with @@ -73,12 +83,16 @@ func New(opts Options) *Server { ctx = context.Background() } s := &Server{ - apiKey: opts.APIKey, - reload: opts.Reload, - logger: opts.Logger, - ctx: ctx, + apiKey: opts.APIKey, + reload: opts.Reload, + status: opts.Status, + syncNow: opts.SyncNow, + logger: opts.Logger, + ctx: ctx, } mux.HandleFunc("/v1/reload", s.handleReload) + mux.HandleFunc("/v1/status", s.handleStatus) + mux.HandleFunc("/v1/sync", s.handleSync) s.server = &http.Server{ Addr: opts.Addr, Handler: mux, @@ -111,6 +125,10 @@ func (s *Server) handleReload(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) return } + if s.reload == nil { + writeJSON(w, http.StatusNotFound, errorResponse{Error: "reload is unavailable in managed mode"}) + return + } // Use the server-scoped context, not r.Context(): reload mutates // process-wide state and must not be cut short by a client disconnect. @@ -132,6 +150,45 @@ func (s *Server) handleReload(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "internal error"}) } +// handleStatus serves the applied control-plane state (managed mode). +func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { + if !s.authorize(r) { + writeJSON(w, http.StatusUnauthorized, errorResponse{Error: "unauthorized"}) + return + } + if r.Method != http.MethodGet { + w.Header().Set("Allow", http.MethodGet) + writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) + return + } + if s.status == nil { + writeJSON(w, http.StatusNotFound, errorResponse{Error: "status is unavailable in standalone mode"}) + return + } + writeJSON(w, http.StatusOK, s.status()) +} + +// handleSync requests an immediate control-plane sync (managed mode). The +// sync itself is asynchronous: callers poll /v1/status to observe the +// applied result. +func (s *Server) handleSync(w http.ResponseWriter, r *http.Request) { + if !s.authorize(r) { + writeJSON(w, http.StatusUnauthorized, errorResponse{Error: "unauthorized"}) + return + } + if r.Method != http.MethodPost { + w.Header().Set("Allow", http.MethodPost) + writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) + return + } + if s.syncNow == nil { + writeJSON(w, http.StatusNotFound, errorResponse{Error: "sync is unavailable in standalone mode"}) + return + } + s.syncNow() + writeJSON(w, http.StatusAccepted, statusResponse{Status: "sync requested"}) +} + func (s *Server) authorize(r *http.Request) bool { const prefix = "Bearer " h := r.Header.Get("Authorization") diff --git a/internal/management/server_test.go b/internal/management/server_test.go index ead6b6c..e6296a6 100644 --- a/internal/management/server_test.go +++ b/internal/management/server_test.go @@ -121,3 +121,76 @@ func TestUnknownPath(t *testing.T) { rec := do(t, s, http.MethodPost, "/anything-else", "Bearer secret") require.Equal(t, http.StatusNotFound, rec.Code) } + +func newManagedTestServer(t *testing.T, key string, status func() any, syncNow func()) *Server { + t.Helper() + return New(Options{ + Addr: "127.0.0.1:0", + APIKey: key, + Status: status, + SyncNow: syncNow, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + }) +} + +func TestStatus_Success(t *testing.T) { + s := newManagedTestServer(t, "secret", func() any { + return map[string]any{"principal_id": "prn_1", "synced_once": true} + }, func() {}) + rec := do(t, s, http.MethodGet, "/v1/status", "Bearer secret") + require.Equal(t, http.StatusOK, rec.Code) + var body map[string]any + require.NoError(t, json.NewDecoder(rec.Body).Decode(&body)) + require.Equal(t, "prn_1", body["principal_id"]) + require.Equal(t, true, body["synced_once"]) +} + +func TestStatus_MissingAuth(t *testing.T) { + s := newManagedTestServer(t, "secret", func() any { return nil }, func() {}) + rec := do(t, s, http.MethodGet, "/v1/status", "") + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestStatus_WrongMethod(t *testing.T) { + s := newManagedTestServer(t, "secret", func() any { return nil }, func() {}) + rec := do(t, s, http.MethodPost, "/v1/status", "Bearer secret") + require.Equal(t, http.StatusMethodNotAllowed, rec.Code) +} + +func TestStatus_StandaloneUnavailable(t *testing.T) { + s := newTestServer(t, "secret", func(context.Context) error { return nil }) + rec := do(t, s, http.MethodGet, "/v1/status", "Bearer secret") + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestSync_Success(t *testing.T) { + var poked bool + s := newManagedTestServer(t, "secret", func() any { return nil }, func() { poked = true }) + rec := do(t, s, http.MethodPost, "/v1/sync", "Bearer secret") + require.Equal(t, http.StatusAccepted, rec.Code) + require.True(t, poked) +} + +func TestSync_MissingAuth(t *testing.T) { + s := newManagedTestServer(t, "secret", func() any { return nil }, func() {}) + rec := do(t, s, http.MethodPost, "/v1/sync", "") + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestSync_WrongMethod(t *testing.T) { + s := newManagedTestServer(t, "secret", func() any { return nil }, func() {}) + rec := do(t, s, http.MethodGet, "/v1/sync", "Bearer secret") + require.Equal(t, http.StatusMethodNotAllowed, rec.Code) +} + +func TestSync_StandaloneUnavailable(t *testing.T) { + s := newTestServer(t, "secret", func(context.Context) error { return nil }) + rec := do(t, s, http.MethodPost, "/v1/sync", "Bearer secret") + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestReload_ManagedUnavailable(t *testing.T) { + s := newManagedTestServer(t, "secret", func() any { return nil }, func() {}) + rec := do(t, s, http.MethodPost, "/v1/reload", "Bearer secret") + require.Equal(t, http.StatusNotFound, rec.Code) +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 1900d38..6bcec6c 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -56,8 +56,11 @@ type Proxy struct { // 443 in production so a client-supplied CONNECT port cannot pivot an // allowlisted hostname onto a different port. Overridable in tests. sniUpstreamPort string + ready func() bool } +const notReadyMessage = "proxy is not ready: awaiting control-plane config" + // Options configures Proxy construction. type Options struct { HTTPAddr string @@ -78,6 +81,12 @@ type Options struct { // an upstream SOCKS5/HTTP CONNECT proxy (see http.Transport.Proxy). nil // means connect directly. Use config.UpstreamProxy.ProxyFunc to build one. UpstreamProxy func(*http.Request) (*url.URL, error) + // Ready, when non-nil, gates request handling: while it returns false + // every proxied request is rejected with 503. Managed proxies use it to + // fail closed until the first control-plane config has been applied, so + // requests can never pass through un-transformed (leaking placeholder + // credentials upstream) during startup. + Ready func() bool } // New creates a new Proxy. In TLSModeMITM, certCache must be non-nil. In @@ -92,6 +101,7 @@ func New(opts Options) *Proxy { guard, _ = dnsguard.New(nil) } p := &Proxy{ + ready: opts.Ready, httpsAddr: opts.HTTPSAddr, tlsMode: opts.TLSMode, tunnelAddr: opts.TunnelAddr, @@ -247,6 +257,31 @@ func (p *Proxy) beginPipelineRun(result *transform.PipelineResult) (*transform.P } } +func (p *Proxy) isReady() bool { + return p.ready == nil || p.ready() +} + +func markNotReady(result *transform.PipelineResult) { + result.Action = transform.ActionReject + result.StatusCode = http.StatusServiceUnavailable + result.RequestTransforms = append(result.RequestTransforms, transform.TransformTrace{ + Name: "ready", + Action: transform.ActionReject, + Annotations: map[string]any{ + "reason": "awaiting_control_plane_config", + }, + }) +} + +func notReadyResponse() *http.Response { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Status: "503 " + http.StatusText(http.StatusServiceUnavailable), + Header: http.Header{"Content-Type": []string{"text/plain; charset=utf-8"}}, + Body: io.NopCloser(strings.NewReader(notReadyMessage + "\n")), + } +} + func (p *Proxy) handleDirectHTTP(w http.ResponseWriter, r *http.Request) { p.handleHTTP(w, r, nil) } @@ -313,6 +348,12 @@ func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request, tunnelInfo *t pl, finish := p.beginPipelineRun(result) defer finish() + if !p.isReady() { + markNotReady(result) + http.Error(w, notReadyMessage, http.StatusServiceUnavailable) + return + } + bodyLimits := pl.BodyLimits() // Wrap request body for lazy buffering by transforms. r.Body = transform.NewBufferedBody(r.Body, bodyLimits.MaxRequestBodyBytes) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2ca3c48..31151a2 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -19,6 +19,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "testing" "time" @@ -832,3 +833,59 @@ func TestContainsDotSegments(t *testing.T) { }) } } + +func TestHTTPProxy_FailsClosedUntilReady(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + caCert, caKey := generateTestCA(t) + cache, err := certcache.NewFromCA(caCert, caKey, 100, 72*time.Hour) + require.NoError(t, err) + + pipeline := transform.NewPipeline(nil, transform.BodyLimits{}, testLogger()) + audits := make(chan transform.PipelineResult, 2) + pipeline.SetAuditFunc(func(r *transform.PipelineResult) { + audits <- *r + }) + holder := transform.NewPipelineHolder(pipeline) + + var ready atomic.Bool + p := New(Options{ + HTTPAddr: "127.0.0.1:0", + CertCache: cache, + Pipeline: holder, + Logger: testLogger(), + Ready: ready.Load, + }) + + httpLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = p.httpServer.Serve(httpLn) }() + t.Cleanup(func() { _ = p.httpServer.Close() }) + + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/test", httpLn.Addr()), nil) + require.NoError(t, err) + req.Host = upstream.Listener.Addr().String() + + // Not ready: the proxy must reject rather than pass the request through + // an un-synced pipeline. + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + audit := <-audits + require.Equal(t, transform.ActionReject, audit.Action) + require.Equal(t, http.StatusServiceUnavailable, audit.StatusCode) + require.Len(t, audit.RequestTransforms, 1) + require.Equal(t, "ready", audit.RequestTransforms[0].Name) + require.Equal(t, transform.ActionReject, audit.RequestTransforms[0].Action) + + // Ready: the same request flows. + ready.Store(true) + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/internal/proxy/sni_passthrough.go b/internal/proxy/sni_passthrough.go index 84d3118..9beeb44 100644 --- a/internal/proxy/sni_passthrough.go +++ b/internal/proxy/sni_passthrough.go @@ -53,6 +53,11 @@ func (p *Proxy) serveSNIPassthrough(clientConn net.Conn) error { pl, finish := p.beginPipelineRun(result) defer finish() + if !p.isReady() { + markNotReady(result) + return nil + } + if sni == "" { result.Action = transform.ActionReject result.StatusCode = http.StatusBadRequest diff --git a/internal/proxy/sni_passthrough_test.go b/internal/proxy/sni_passthrough_test.go index d9cd678..b8ad0ac 100644 --- a/internal/proxy/sni_passthrough_test.go +++ b/internal/proxy/sni_passthrough_test.go @@ -18,6 +18,7 @@ import ( "net/http/httptest" "os" "sync" + "sync/atomic" "testing" "time" @@ -223,6 +224,50 @@ func TestSNIPassthrough_HappyPath(t *testing.T) { require.Equal(t, transform.ActionContinue, results[0].Action) } +func TestSNIPassthrough_FailsClosedUntilReady(t *testing.T) { + upstream, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { _ = upstream.Close() }) + + accepted := make(chan struct{}, 1) + go func() { + conn, err := upstream.Accept() + if err != nil { + return + } + _ = conn.Close() + accepted <- struct{}{} + }() + + _, upstreamPort, err := net.SplitHostPort(upstream.Addr().String()) + require.NoError(t, err) + + var ready atomic.Bool + p, getResults := buildSNIProxy(t, []string{"localhost"}, false) + p.ready = ready.Load + p.sniUpstreamPort = upstreamPort + + proxyAddr := startAcceptLoop(t, func(c net.Conn) { _ = p.serveSNIPassthrough(c) }) + + conn, err := net.Dial("tcp", proxyAddr) + require.NoError(t, err) + tlsConn := tls.Client(conn, &tls.Config{ServerName: "localhost"}) + require.Error(t, tlsConn.Handshake()) + _ = tlsConn.Close() + + require.Eventually(t, func() bool { return len(getResults()) == 1 }, 2*time.Second, 10*time.Millisecond) + results := getResults() + require.Equal(t, transform.ActionReject, results[0].Action) + require.Equal(t, http.StatusServiceUnavailable, results[0].StatusCode) + require.Equal(t, "ready", results[0].RequestTransforms[0].Name) + + select { + case <-accepted: + t.Fatal("sni-only proxy dialed upstream before it was ready") + case <-time.After(100 * time.Millisecond): + } +} + func TestSNIPassthrough_AllowlistDeny(t *testing.T) { upstream, _ := startEchoTLSServer(t) proxyAddr, getResults := startSNIPassthroughProxy(t, []string{"allowed.example"}, upstream) diff --git a/internal/proxy/tunnel.go b/internal/proxy/tunnel.go index 8485300..3bd87d4 100644 --- a/internal/proxy/tunnel.go +++ b/internal/proxy/tunnel.go @@ -286,6 +286,11 @@ func (p *Proxy) tunnelTransformCheck(remoteAddr, target string, connectHeaders h pl, finish := p.beginPipelineRun(result) defer finish() + if !p.isReady() { + markNotReady(result) + return false, notReadyResponse(), nil + } + rejectResp, err := pl.ProcessRequest(req.Context(), tctx, req, &result.RequestTransforms) if err != nil { result.Action = transform.ActionContinue