From ba76e81ec241b3388d23e7f243103ffc2da00fb1 Mon Sep 17 00:00:00 2001 From: Brad Stein Date: Sat, 4 Apr 2026 22:24:56 -0300 Subject: [PATCH] hecate: harden startup recovery and ssh/state self-heal --- cmd/hecate/main.go | 49 ++++--- configs/hecate.example.yaml | 4 + configs/hecate.tethys.yaml | 4 + configs/hecate.titan-db.yaml | 4 + internal/cluster/orchestrator.go | 244 +++++++++++++++++++++++++++++++ internal/config/config.go | 20 +++ internal/service/daemon.go | 34 ++++- internal/sshutil/sshutil.go | 143 ++++++++++++++++++ internal/sshutil/sshutil_test.go | 43 ++++++ internal/state/heal.go | 22 +++ internal/state/intent.go | 5 +- internal/state/intent_test.go | 31 ++++ internal/state/store.go | 5 +- internal/state/store_test.go | 30 ++++ scripts/install.sh | 8 +- 15 files changed, 620 insertions(+), 26 deletions(-) create mode 100644 internal/sshutil/sshutil.go create mode 100644 internal/sshutil/sshutil_test.go create mode 100644 internal/state/heal.go diff --git a/cmd/hecate/main.go b/cmd/hecate/main.go index eae3341..4ca586a 100644 --- a/cmd/hecate/main.go +++ b/cmd/hecate/main.go @@ -19,6 +19,7 @@ import ( "scm.bstein.dev/bstein/hecate/internal/config" "scm.bstein.dev/bstein/hecate/internal/execx" "scm.bstein.dev/bstein/hecate/internal/service" + "scm.bstein.dev/bstein/hecate/internal/sshutil" "scm.bstein.dev/bstein/hecate/internal/state" "scm.bstein.dev/bstein/hecate/internal/ups" ) @@ -442,18 +443,12 @@ func tryPeerBootstrapHandoff(ctx context.Context, cfg config.Config, logger *log attempt := 1 for { cmdArgs := append(append([]string{}, args...), target, remote) - cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) - out, err := cmd.CombinedOutput() + _, err := runSSHWithRecovery(ctx, logger, cfg, cmdArgs, []string{coordinator, host, cfg.SSHJumpHost}) if err == nil { logger.Printf("peer bootstrap handoff succeeded on %s (attempt=%d)", coordinator, attempt) return true, nil } - trimmed := strings.TrimSpace(string(out)) - if trimmed == "" { - logger.Printf("peer bootstrap handoff attempt %d failed for %s: %v", attempt, coordinator, err) - } else { - logger.Printf("peer bootstrap handoff attempt %d failed for %s: %v: %s", attempt, coordinator, err, trimmed) - } + logger.Printf("peer bootstrap handoff attempt %d failed for %s: %v", attempt, coordinator, err) select { case <-ctx.Done(): @@ -487,18 +482,12 @@ func coordinatorAllowsPeerFallbackStartup(ctx context.Context, cfg config.Config } remoteCmd := "sudo -n sh -lc 'if systemctl is-active --quiet hecate-bootstrap.service; then echo __HECATE_BOOTSTRAP_ACTIVE__; else echo __HECATE_BOOTSTRAP_IDLE__; fi; if [ -s /var/lib/hecate/intent.json ]; then cat /var/lib/hecate/intent.json; else echo \"{}\"; fi'" args := append(buildSSHBaseArgs(cfg), target, remoteCmd) - cmd := exec.CommandContext(ctx, "ssh", args...) 
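// Annotation (not part of the patch): this inline ssh + CombinedOutput
// pattern recurs in tryPeerBootstrapHandoff above and in daemon.go's
// forwardShutdown; the patch replaces each copy with runSSHWithRecovery so
// the host-key repair-and-retry logic lives in one place. A minimal sketch
// of how the guard output is then interpreted (guardAllowsFallback is a
// hypothetical helper name, not in the patch):
//
//	func guardAllowsFallback(out string) bool {
//		// The remote command prints __HECATE_BOOTSTRAP_ACTIVE__ or
//		// __HECATE_BOOTSTRAP_IDLE__ followed by the intent JSON;
//		// substring matching keeps motd/sudo noise from breaking it.
//		return !strings.Contains(strings.TrimSpace(out), "__HECATE_BOOTSTRAP_ACTIVE__")
//	}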
- out, err := cmd.CombinedOutput() + out, err := runSSHWithRecovery(ctx, logger, cfg, args, []string{coordinator, host, cfg.SSHJumpHost}) if err != nil { - trimmed := strings.TrimSpace(string(out)) - if trimmed == "" { - logger.Printf("warning: coordinator guard check unavailable on %s: %v; allowing peer fallback startup", coordinator, err) - } else { - logger.Printf("warning: coordinator guard check unavailable on %s: %v: %s; allowing peer fallback startup", coordinator, err, trimmed) - } + logger.Printf("warning: coordinator guard check unavailable on %s: %v; allowing peer fallback startup", coordinator, err) return true, "coordinator unreachable", nil } - trimmed := strings.TrimSpace(string(out)) + trimmed := strings.TrimSpace(out) if strings.Contains(trimmed, "__HECATE_BOOTSTRAP_ACTIVE__") { return false, "coordinator bootstrap service is active", nil } @@ -546,6 +535,32 @@ func coordinatorAllowsPeerFallbackStartup(ctx context.Context, cfg config.Config } } +func runSSHWithRecovery(ctx context.Context, logger *log.Logger, cfg config.Config, args []string, repairHosts []string) (string, error) { + try := func() (string, error) { + cmd := exec.CommandContext(ctx, "ssh", args...) + out, err := cmd.CombinedOutput() + trimmed := strings.TrimSpace(string(out)) + if err != nil { + if trimmed == "" { + return "", fmt.Errorf("ssh failed: %w", err) + } + return trimmed, fmt.Errorf("ssh failed: %w: %s", err, trimmed) + } + return trimmed, nil + } + + out, err := try() + if err == nil { + return out, nil + } + if !sshutil.IsHostKeyError(out, err) { + return out, err + } + + sshutil.RepairKnownHosts(ctx, logger, sshutil.KnownHostsFiles(resolveSSHConfigFile(cfg), resolveSSHIdentityFile(cfg)), repairHosts, cfg.SSHPort) + return try() +} + func buildSSHBaseArgs(cfg config.Config) []string { args := []string{ "-o", "BatchMode=yes", diff --git a/configs/hecate.example.yaml b/configs/hecate.example.yaml index 382cb41..a626ac7 100644 --- a/configs/hecate.example.yaml +++ b/configs/hecate.example.yaml @@ -42,6 +42,10 @@ excluded_namespaces: startup: api_wait_seconds: 1200 api_poll_seconds: 2 + require_time_sync: true + time_sync_wait_seconds: 240 + time_sync_poll_seconds: 5 + reconcile_access_on_boot: true auto_etcd_restore_on_api_failure: true etcd_restore_control_plane: titan-0a shutdown: diff --git a/configs/hecate.tethys.yaml b/configs/hecate.tethys.yaml index 03b100b..2f3d0ef 100644 --- a/configs/hecate.tethys.yaml +++ b/configs/hecate.tethys.yaml @@ -108,6 +108,10 @@ excluded_namespaces: startup: api_wait_seconds: 1200 api_poll_seconds: 2 + require_time_sync: true + time_sync_wait_seconds: 240 + time_sync_poll_seconds: 5 + reconcile_access_on_boot: true auto_etcd_restore_on_api_failure: true etcd_restore_control_plane: titan-0a shutdown: diff --git a/configs/hecate.titan-db.yaml b/configs/hecate.titan-db.yaml index 82bb826..dada801 100644 --- a/configs/hecate.titan-db.yaml +++ b/configs/hecate.titan-db.yaml @@ -108,6 +108,10 @@ excluded_namespaces: startup: api_wait_seconds: 1200 api_poll_seconds: 2 + require_time_sync: true + time_sync_wait_seconds: 240 + time_sync_poll_seconds: 5 + reconcile_access_on_boot: true auto_etcd_restore_on_api_failure: true etcd_restore_control_plane: titan-0a shutdown: diff --git a/internal/cluster/orchestrator.go b/internal/cluster/orchestrator.go index f51fb8c..b28409c 100644 --- a/internal/cluster/orchestrator.go +++ b/internal/cluster/orchestrator.go @@ -7,9 +7,12 @@ import ( "errors" "fmt" "log" + "net" + neturl "net/url" "os" "os/exec" "path/filepath" + "regexp" 
"sort" "strconv" "strings" @@ -18,6 +21,7 @@ import ( "scm.bstein.dev/bstein/hecate/internal/config" "scm.bstein.dev/bstein/hecate/internal/execx" + "scm.bstein.dev/bstein/hecate/internal/sshutil" "scm.bstein.dev/bstein/hecate/internal/state" ) @@ -63,6 +67,8 @@ type workloadScaleSnapshot struct { Entries []workloadScaleEntry `json:"entries"` } +var datastoreEndpointPattern = regexp.MustCompile(`--datastore-endpoint(?:=|\s+)(?:'([^']+)'|"([^"]+)"|([^\s\\]+))`) + var criticalStartupWorkloads = []startupWorkload{ {Namespace: "flux-system", Kind: "deployment", Name: "source-controller"}, {Namespace: "flux-system", Kind: "deployment", Name: "kustomize-controller"}, @@ -119,6 +125,17 @@ func (o *Orchestrator) Startup(ctx context.Context, opts StartupOptions) (err er } o.log.Printf("startup control-planes=%s", strings.Join(o.cfg.ControlPlanes, ",")) + if o.cfg.Startup.RequireTimeSync { + if err := o.waitForTimeSync(ctx, o.cfg.ControlPlanes); err != nil { + return err + } + } + if err := o.preflightExternalDatastore(ctx); err != nil { + return err + } + if o.cfg.Startup.ReconcileAccessOnBoot { + o.bestEffort("reconcile control-plane access", func() error { return o.reconcileNodeAccess(ctx, o.cfg.ControlPlanes) }) + } o.reportFluxSource(ctx, opts.ForceFluxBranch) o.startControlPlanes(ctx, o.cfg.ControlPlanes) @@ -156,6 +173,9 @@ func (o *Orchestrator) Startup(ctx context.Context, opts StartupOptions) (err er return err } o.log.Printf("startup workers=%s", strings.Join(workers, ",")) + if o.cfg.Startup.ReconcileAccessOnBoot { + o.bestEffort("reconcile worker access", func() error { return o.reconcileNodeAccess(ctx, workers) }) + } o.startWorkers(ctx, workers) o.bestEffort("uncordon workers", func() error { return o.uncordonWorkers(ctx, workers) }) @@ -873,6 +893,214 @@ func (o *Orchestrator) waitForAPI(ctx context.Context, attempts int, sleep time. 
return fmt.Errorf("kubernetes API did not become reachable within timeout") } +func (o *Orchestrator) waitForTimeSync(ctx context.Context, nodes []string) error { + if o.runner.DryRun { + return nil + } + wait := time.Duration(o.cfg.Startup.TimeSyncWaitSeconds) * time.Second + if wait <= 0 { + wait = 240 * time.Second + } + poll := time.Duration(o.cfg.Startup.TimeSyncPollSeconds) * time.Second + if poll <= 0 { + poll = 5 * time.Second + } + deadline := time.Now().Add(wait) + for { + unsynced := []string{} + localOut, localErr := o.run(ctx, 10*time.Second, "sh", "-lc", "timedatectl show -p NTPSynchronized --value 2>/dev/null || echo unknown") + if localErr != nil || !isTimeSynced(localOut) { + if localErr != nil { + unsynced = append(unsynced, fmt.Sprintf("local(%v)", localErr)) + } else { + unsynced = append(unsynced, fmt.Sprintf("local(%s)", strings.TrimSpace(localOut))) + } + } + for _, node := range nodes { + node = strings.TrimSpace(node) + if node == "" { + continue + } + if !o.sshManaged(node) { + continue + } + out, err := o.ssh(ctx, node, "timedatectl show -p NTPSynchronized --value 2>/dev/null || echo unknown") + if err != nil || !isTimeSynced(out) { + if err != nil { + unsynced = append(unsynced, fmt.Sprintf("%s(%v)", node, err)) + } else { + unsynced = append(unsynced, fmt.Sprintf("%s(%s)", node, strings.TrimSpace(out))) + } + } + } + if len(unsynced) == 0 { + return nil + } + if time.Now().After(deadline) { + return fmt.Errorf("startup blocked: time sync not ready within %s (%s)", wait, strings.Join(unsynced, ", ")) + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(poll): + } + } +} + +func isTimeSynced(raw string) bool { + v := strings.ToLower(strings.TrimSpace(raw)) + return v == "yes" || v == "true" || v == "1" +} + +func (o *Orchestrator) preflightExternalDatastore(ctx context.Context) error { + if len(o.cfg.ControlPlanes) == 0 { + return nil + } + controlPlane := strings.TrimSpace(o.cfg.ControlPlanes[0]) + if controlPlane == "" || !o.sshManaged(controlPlane) { + return nil + } + unitOut, err := o.ssh(ctx, controlPlane, "sudo systemctl cat k3s") + if err != nil { + o.log.Printf("warning: external datastore preflight skipped: unable to inspect %s k3s unit: %v", controlPlane, err) + return nil + } + datastoreEndpoint := parseDatastoreEndpoint(unitOut) + if datastoreEndpoint == "" { + return nil + } + u, err := neturl.Parse(datastoreEndpoint) + if err != nil || u.Host == "" { + o.log.Printf("warning: external datastore preflight skipped: unable to parse datastore endpoint %q", datastoreEndpoint) + return nil + } + host := strings.TrimSpace(u.Hostname()) + port := strings.TrimSpace(u.Port()) + if port == "" { + port = "5432" + } + address := net.JoinHostPort(host, port) + if o.tcpReachable(address, 3*time.Second) { + return nil + } + o.log.Printf("warning: datastore endpoint %s is unreachable; attempting software recovery", address) + if node := o.nodeNameForHost(host); node != "" && o.sshManaged(node) { + o.bestEffort("restart datastore service on "+node, func() error { + _, err := o.ssh(ctx, node, "sudo systemctl restart postgresql || sudo systemctl restart postgresql@16-main || sudo systemctl restart postgres") + return err + }) + } + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + if o.tcpReachable(address, 3*time.Second) { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(3 * time.Second): + } + } + return fmt.Errorf("startup blocked: external datastore endpoint %s remained 
unreachable after recovery attempt", address) +} + +func parseDatastoreEndpoint(unitText string) string { + if match := datastoreEndpointPattern.FindStringSubmatch(unitText); len(match) == 4 { + for _, candidate := range match[1:] { + value := strings.TrimSpace(candidate) + if value != "" { + return value + } + } + } + + for _, raw := range strings.Split(unitText, "\n") { + line := strings.TrimSpace(raw) + idx := strings.Index(line, "--datastore-endpoint") + if idx < 0 { + continue + } + value := strings.TrimSpace(line[idx+len("--datastore-endpoint"):]) + value = strings.TrimSpace(strings.TrimPrefix(value, "=")) + value = strings.TrimSuffix(strings.TrimSpace(value), "\\") + value = strings.Trim(value, `"'`) + if value != "" { + return value + } + } + return "" +} + +func (o *Orchestrator) nodeNameForHost(host string) string { + host = strings.TrimSpace(host) + if host == "" { + return "" + } + if _, ok := o.cfg.SSHNodeHosts[host]; ok { + return host + } + for node, mapped := range o.cfg.SSHNodeHosts { + if strings.TrimSpace(mapped) == host { + return strings.TrimSpace(node) + } + } + return "" +} + +func (o *Orchestrator) tcpReachable(address string, timeout time.Duration) bool { + conn, err := net.DialTimeout("tcp", address, timeout) + if err != nil { + return false + } + _ = conn.Close() + return true +} + +func (o *Orchestrator) reconcileNodeAccess(ctx context.Context, nodes []string) error { + if len(nodes) == 0 { + return nil + } + parallelism := o.cfg.Shutdown.SSHParallelism + if parallelism <= 0 { + parallelism = 8 + } + if parallelism > len(nodes) { + parallelism = len(nodes) + } + sem := make(chan struct{}, parallelism) + var wg sync.WaitGroup + errCh := make(chan error, len(nodes)) + cmd := `sudo sh -lc 'id atlas >/dev/null 2>&1 || useradd -m -s /bin/bash atlas || true; install -d -m 0755 /etc/sudoers.d; printf "%s\n" "atlas ALL=(ALL) NOPASSWD: /usr/bin/systemctl, /usr/sbin/poweroff, /sbin/poweroff, /usr/local/bin/hecate" > /etc/sudoers.d/90-hecate-atlas; chmod 0440 /etc/sudoers.d/90-hecate-atlas; if command -v visudo >/dev/null 2>&1; then visudo -cf /etc/sudoers.d/90-hecate-atlas >/dev/null; fi'` + for _, node := range nodes { + node := strings.TrimSpace(node) + if node == "" || !o.sshManaged(node) { + continue + } + wg.Add(1) + go func() { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + if _, err := o.ssh(ctx, node, cmd); err != nil { + errCh <- fmt.Errorf("%s: %w", node, err) + } + }() + } + wg.Wait() + close(errCh) + if len(errCh) == 0 { + return nil + } + samples := []string{} + for err := range errCh { + samples = append(samples, err.Error()) + if len(samples) >= 4 { + break + } + } + return fmt.Errorf("access reconcile had %d errors (first: %s)", len(errCh), strings.Join(samples, " | ")) +} + func (o *Orchestrator) fluxSourceReady(ctx context.Context) (bool, error) { out, err := o.kubectl(ctx, 10*time.Second, "-n", "flux-system", "get", "gitrepository", "flux-system", "-o", "jsonpath={.status.conditions[?(@.type==\"Ready\")].status}") if err != nil { @@ -1027,8 +1255,14 @@ func (o *Orchestrator) ssh(ctx context.Context, node string, command string) (st } attempts := make([][]string, 0, 2) attemptNames := make([]string, 0, 2) + knownHostsFiles := sshutil.KnownHostsFiles(sshConfigFile, sshIdentity) + repairHosts := []string{node, host} if o.cfg.SSHJumpHost != "" { jump := o.cfg.SSHJumpHost + repairHosts = append(repairHosts, jump) + if mapped, ok := o.cfg.SSHNodeHosts[jump]; ok && strings.TrimSpace(mapped) != "" { + repairHosts = append(repairHosts, 
strings.TrimSpace(mapped)) + } if o.cfg.SSHJumpUser != "" { jump = o.cfg.SSHJumpUser + "@" + jump } @@ -1055,6 +1289,16 @@ func (o *Orchestrator) ssh(ctx context.Context, node string, command string) (st } return out, nil } + if sshutil.IsHostKeyError(out, err) { + o.log.Printf("warning: ssh host-key mismatch detected for %s via %s path; repairing known_hosts and retrying once", node, attemptNames[i]) + sshutil.RepairKnownHosts(ctx, o.log, knownHostsFiles, repairHosts, o.cfg.SSHPort) + retryOut, retryErr := o.run(ctx, 45*time.Second, "ssh", args...) + if retryErr == nil { + return retryOut, nil + } + out = retryOut + err = retryErr + } lastOut = out lastErr = err if i < len(attempts)-1 { diff --git a/internal/config/config.go b/internal/config/config.go index 40c1344..5793c86 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -35,6 +35,10 @@ type Config struct { type Startup struct { APIWaitSeconds int `yaml:"api_wait_seconds"` APIPollSeconds int `yaml:"api_poll_seconds"` + RequireTimeSync bool `yaml:"require_time_sync"` + TimeSyncWaitSeconds int `yaml:"time_sync_wait_seconds"` + TimeSyncPollSeconds int `yaml:"time_sync_poll_seconds"` + ReconcileAccessOnBoot bool `yaml:"reconcile_access_on_boot"` AutoEtcdRestoreOnAPIFailure bool `yaml:"auto_etcd_restore_on_api_failure"` EtcdRestoreControlPlane string `yaml:"etcd_restore_control_plane"` } @@ -138,6 +142,12 @@ func (c Config) Validate() error { if c.Startup.APIPollSeconds <= 0 { return fmt.Errorf("config.startup.api_poll_seconds must be > 0") } + if c.Startup.TimeSyncWaitSeconds <= 0 { + return fmt.Errorf("config.startup.time_sync_wait_seconds must be > 0") + } + if c.Startup.TimeSyncPollSeconds <= 0 { + return fmt.Errorf("config.startup.time_sync_poll_seconds must be > 0") + } if c.Startup.EtcdRestoreControlPlane != "" { found := false for _, cp := range c.ControlPlanes { @@ -220,6 +230,10 @@ func defaults() Config { Startup: Startup{ APIWaitSeconds: 1200, APIPollSeconds: 2, + RequireTimeSync: true, + TimeSyncWaitSeconds: 240, + TimeSyncPollSeconds: 5, + ReconcileAccessOnBoot: true, AutoEtcdRestoreOnAPIFailure: true, EtcdRestoreControlPlane: "titan-0a", }, @@ -277,6 +291,12 @@ func (c *Config) applyDefaults() { if c.Startup.APIPollSeconds <= 0 { c.Startup.APIPollSeconds = 2 } + if c.Startup.TimeSyncWaitSeconds <= 0 { + c.Startup.TimeSyncWaitSeconds = 240 + } + if c.Startup.TimeSyncPollSeconds <= 0 { + c.Startup.TimeSyncPollSeconds = 5 + } if c.Startup.EtcdRestoreControlPlane == "" && len(c.ControlPlanes) > 0 { c.Startup.EtcdRestoreControlPlane = c.ControlPlanes[0] } diff --git a/internal/service/daemon.go b/internal/service/daemon.go index 1f98a41..9e7ce8d 100644 --- a/internal/service/daemon.go +++ b/internal/service/daemon.go @@ -15,6 +15,7 @@ import ( "scm.bstein.dev/bstein/hecate/internal/cluster" "scm.bstein.dev/bstein/hecate/internal/config" "scm.bstein.dev/bstein/hecate/internal/metrics" + "scm.bstein.dev/bstein/hecate/internal/sshutil" "scm.bstein.dev/bstein/hecate/internal/state" "scm.bstein.dev/bstein/hecate/internal/ups" ) @@ -240,14 +241,35 @@ func (d *Daemon) forwardShutdown(ctx context.Context, reason string) error { args = append(args, "-J", jump) } args = append(args, target, remoteCmd) - cmd := exec.CommandContext(runCtx, "ssh", args...) - out, err := cmd.CombinedOutput() - if err != nil { + + try := func() (string, error) { + cmd := exec.CommandContext(runCtx, "ssh", args...) 
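// Annotation (not part of the patch): CombinedOutput is deliberate on the
// next line. ssh prints host-key warnings on stderr, and IsHostKeyError
// classifies the failure from that text, so stdout and stderr must stay
// merged for the repair-and-retry path below to trigger.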
+ out, err := cmd.CombinedOutput() trimmed := strings.TrimSpace(string(out)) - if trimmed == "" { - return fmt.Errorf("forward shutdown via ssh failed: %w", err) + if err != nil { + if trimmed == "" { + return "", fmt.Errorf("forward shutdown via ssh failed: %w", err) + } + return trimmed, fmt.Errorf("forward shutdown via ssh failed: %w: %s", err, trimmed) } - return fmt.Errorf("forward shutdown via ssh failed: %w: %s", err, trimmed) + return trimmed, nil + } + + out, err := try() + if err != nil && sshutil.IsHostKeyError(out, err) { + repairHosts := []string{d.cfg.Coordination.ForwardShutdownHost, host} + if d.cfg.SSHJumpHost != "" { + repairHosts = append(repairHosts, d.cfg.SSHJumpHost) + } + sshutil.RepairKnownHosts(runCtx, d.log, sshutil.KnownHostsFiles(d.resolveSSHConfigFile(), d.resolveSSHIdentityFile()), repairHosts, d.cfg.SSHPort) + if _, err2 := try(); err2 == nil { + return nil + } else { + return err2 + } + } + if err != nil { + return err } return nil } diff --git a/internal/sshutil/sshutil.go b/internal/sshutil/sshutil.go new file mode 100644 index 0000000..9347a72 --- /dev/null +++ b/internal/sshutil/sshutil.go @@ -0,0 +1,143 @@ +package sshutil + +import ( + "context" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +var hostKeyErrorMarkers = []string{ + "remote host identification has changed", + "host key verification failed", + "offending ", + "possible dns spoofing detected", +} + +func IsHostKeyError(output string, err error) bool { + if err == nil { + return false + } + combined := strings.ToLower(strings.TrimSpace(output + "\n" + err.Error())) + if combined == "" { + return false + } + for _, marker := range hostKeyErrorMarkers { + if strings.Contains(combined, marker) { + return true + } + } + return false +} + +func KnownHostsFiles(sshConfigFile, sshIdentityFile string) []string { + seen := map[string]struct{}{} + add := func(path string) { + p := strings.TrimSpace(path) + if p == "" { + return + } + if _, ok := seen[p]; ok { + return + } + seen[p] = struct{}{} + } + + // Common locations for this environment. 
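// Annotation (not part of the patch): the hard-coded paths below encode an
// assumption about this cluster's layout (a root-run daemon plus the atlas
// and tethys accounts). Over-listing is harmless: RepairKnownHosts stats
// each candidate and skips files that do not exist.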
+ add("/root/.ssh/known_hosts") + add("/home/atlas/.ssh/known_hosts") + add("/home/tethys/.ssh/known_hosts") + + if home, err := os.UserHomeDir(); err == nil && strings.TrimSpace(home) != "" { + add(filepath.Join(home, ".ssh", "known_hosts")) + } + + if cfg := strings.TrimSpace(sshConfigFile); cfg != "" { + add(filepath.Join(filepath.Dir(cfg), "known_hosts")) + } + if key := strings.TrimSpace(sshIdentityFile); key != "" { + add(filepath.Join(filepath.Dir(key), "known_hosts")) + } + + out := make([]string, 0, len(seen)) + for path := range seen { + out = append(out, path) + } + return out +} + +func RepairKnownHosts(ctx context.Context, logger *log.Logger, knownHostsFiles []string, hosts []string, port int) { + if _, err := exec.LookPath("ssh-keygen"); err != nil { + logf(logger, "warning: cannot repair known_hosts (ssh-keygen missing): %v", err) + return + } + + dedupHosts := make([]string, 0, len(hosts)) + hostSet := map[string]struct{}{} + for _, h := range hosts { + host := strings.TrimSpace(h) + if host == "" { + continue + } + if _, ok := hostSet[host]; ok { + continue + } + hostSet[host] = struct{}{} + dedupHosts = append(dedupHosts, host) + } + if len(dedupHosts) == 0 { + return + } + + fileSet := map[string]struct{}{} + for _, f := range knownHostsFiles { + file := strings.TrimSpace(f) + if file == "" { + continue + } + if _, ok := fileSet[file]; ok { + continue + } + fileSet[file] = struct{}{} + } + + for file := range fileSet { + if stat, err := os.Stat(file); err != nil || stat.IsDir() { + continue + } + for _, host := range dedupHosts { + removeKnownHostEntry(ctx, logger, file, host) + if port > 0 { + removeKnownHostEntry(ctx, logger, file, fmt.Sprintf("[%s]:%d", host, port)) + } + } + } +} + +func removeKnownHostEntry(ctx context.Context, logger *log.Logger, file string, entry string) { + runCtx, cancel := context.WithTimeout(ctx, 8*time.Second) + defer cancel() + + cmd := exec.CommandContext(runCtx, "ssh-keygen", "-R", entry, "-f", file) + out, err := cmd.CombinedOutput() + if err == nil { + logf(logger, "known_hosts repaired: removed %s from %s", entry, file) + return + } + trimmed := strings.ToLower(strings.TrimSpace(string(out))) + // ssh-keygen exits non-zero when entry is absent; this is fine. + if strings.Contains(trimmed, "not found in") { + return + } + logf(logger, "warning: known_hosts cleanup failed for %s in %s: %v: %s", entry, file, err, strings.TrimSpace(string(out))) +} + +func logf(logger *log.Logger, format string, args ...any) { + if logger != nil { + logger.Printf(format, args...) + } +} diff --git a/internal/sshutil/sshutil_test.go b/internal/sshutil/sshutil_test.go new file mode 100644 index 0000000..68e9078 --- /dev/null +++ b/internal/sshutil/sshutil_test.go @@ -0,0 +1,43 @@ +package sshutil + +import ( + "errors" + "path/filepath" + "testing" +) + +func TestIsHostKeyErrorDetectsMismatch(t *testing.T) { + out := "WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!" 
+ if !IsHostKeyError(out, errors.New("ssh failed")) { + t.Fatalf("expected host-key mismatch to be detected") + } +} + +func TestIsHostKeyErrorIgnoresGenericFailures(t *testing.T) { + out := "connection timed out" + if IsHostKeyError(out, errors.New("ssh failed")) { + t.Fatalf("did not expect host-key mismatch for generic timeout") + } +} + +func TestKnownHostsFilesIncludesDerivedPaths(t *testing.T) { + configFile := "/home/atlas/.ssh/config" + identityFile := "/home/tethys/.ssh/id_ed25519" + files := KnownHostsFiles(configFile, identityFile) + set := map[string]struct{}{} + for _, f := range files { + set[f] = struct{}{} + } + + want := []string{ + "/home/atlas/.ssh/known_hosts", + "/home/tethys/.ssh/known_hosts", + filepath.Join(filepath.Dir(configFile), "known_hosts"), + filepath.Join(filepath.Dir(identityFile), "known_hosts"), + } + for _, path := range want { + if _, ok := set[path]; !ok { + t.Fatalf("expected known_hosts candidate %s", path) + } + } +} diff --git a/internal/state/heal.go b/internal/state/heal.go new file mode 100644 index 0000000..47d0005 --- /dev/null +++ b/internal/state/heal.go @@ -0,0 +1,22 @@ +package state + +import ( + "fmt" + "os" + "path/filepath" + "time" +) + +func quarantineCorruptFile(path string, payload []byte, replacement []byte, mode os.FileMode) error { + if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { + return err + } + backup := fmt.Sprintf("%s.corrupt-%s", path, time.Now().UTC().Format("20060102T150405Z")) + if err := os.WriteFile(backup, payload, 0o600); err != nil { + return fmt.Errorf("write backup %s: %w", backup, err) + } + if err := os.WriteFile(path, replacement, mode); err != nil { + return fmt.Errorf("write replacement %s: %w", path, err) + } + return nil +} diff --git a/internal/state/intent.go b/internal/state/intent.go index f479d51..60be58b 100644 --- a/internal/state/intent.go +++ b/internal/state/intent.go @@ -35,7 +35,10 @@ func ReadIntent(path string) (Intent, error) { } var in Intent if err := json.Unmarshal(b, &in); err != nil { - return Intent{}, err + if healErr := quarantineCorruptFile(path, b, []byte("{}\n"), 0o640); healErr != nil { + return Intent{}, fmt.Errorf("decode intent: %w (auto-heal failed: %v)", err, healErr) + } + return Intent{}, nil } return in, nil } diff --git a/internal/state/intent_test.go b/internal/state/intent_test.go index 70886a5..f7c5f33 100644 --- a/internal/state/intent_test.go +++ b/internal/state/intent_test.go @@ -1,6 +1,7 @@ package state import ( + "os" "path/filepath" "testing" ) @@ -28,3 +29,33 @@ func TestMustWriteIntentRejectsUnknownState(t *testing.T) { t.Fatalf("expected invalid state error") } } + +func TestReadIntentAutoHealsCorruptJSON(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "intent.json") + if err := os.WriteFile(p, []byte("{broken"), 0o640); err != nil { + t.Fatalf("write corrupt intent: %v", err) + } + + in, err := ReadIntent(p) + if err != nil { + t.Fatalf("read intent with auto-heal: %v", err) + } + if in.State != "" { + t.Fatalf("expected empty state after heal, got %q", in.State) + } + raw, err := os.ReadFile(p) + if err != nil { + t.Fatalf("read healed intent file: %v", err) + } + if string(raw) != "{}\n" { + t.Fatalf("expected healed intent payload '{}', got %q", string(raw)) + } + matches, err := filepath.Glob(filepath.Join(dir, "intent.json.corrupt-*")) + if err != nil { + t.Fatalf("glob backup files: %v", err) + } + if len(matches) != 1 { + t.Fatalf("expected 1 backup file, got %d (%v)", len(matches), matches) + } +} diff --git 
a/internal/state/store.go b/internal/state/store.go index e918f47..85b1b0f 100644 --- a/internal/state/store.go +++ b/internal/state/store.go @@ -158,7 +158,10 @@ func (s *Store) loadUnlocked() ([]RunRecord, error) { } var records []RunRecord if err := json.Unmarshal(b, &records); err != nil { - return nil, err + if healErr := quarantineCorruptFile(s.path, b, []byte("[]\n"), 0o640); healErr != nil { + return nil, fmt.Errorf("decode run history: %w (auto-heal failed: %v)", err, healErr) + } + return nil, nil } return records, nil } diff --git a/internal/state/store_test.go b/internal/state/store_test.go index 5a7a826..3eb5dc8 100644 --- a/internal/state/store_test.go +++ b/internal/state/store_test.go @@ -55,3 +55,33 @@ func TestAcquireLockRejectsActiveLock(t *testing.T) { t.Fatalf("expected acquire lock to fail when active pid holds lock") } } + +func TestStoreLoadAutoHealsCorruptJSON(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "runs.json") + if err := os.WriteFile(p, []byte(`{"bad":`), 0o640); err != nil { + t.Fatalf("write corrupt run history: %v", err) + } + + records, err := New(p).Load() + if err != nil { + t.Fatalf("load with auto-heal: %v", err) + } + if len(records) != 0 { + t.Fatalf("expected no records after heal, got %d", len(records)) + } + raw, err := os.ReadFile(p) + if err != nil { + t.Fatalf("read healed runs file: %v", err) + } + if string(raw) != "[]\n" { + t.Fatalf("expected healed runs payload '[]', got %q", string(raw)) + } + matches, err := filepath.Glob(filepath.Join(dir, "runs.json.corrupt-*")) + if err != nil { + t.Fatalf("glob backup files: %v", err) + } + if len(matches) != 1 { + t.Fatalf("expected 1 backup file, got %d (%v)", len(matches), matches) + } +} diff --git a/scripts/install.sh b/scripts/install.sh index bfb30a0..4e5d992 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -199,10 +199,16 @@ migrate_hecate_config() { fi if grep -Eq '^ api_poll_seconds:[[:space:]]*[0-9]+' "${CONF_DIR}/hecate.yaml" \ && ! grep -Eq '^ auto_etcd_restore_on_api_failure:[[:space:]]*(true|false)' "${CONF_DIR}/hecate.yaml"; then - sed -Ei '/^ api_poll_seconds:[[:space:]]*[0-9]+/a\ auto_etcd_restore_on_api_failure: true\n etcd_restore_control_plane: '"${default_restore_cp}"'' "${CONF_DIR}/hecate.yaml" + sed -Ei '/^ api_poll_seconds:[[:space:]]*[0-9]+/a\ require_time_sync: true\n time_sync_wait_seconds: 240\n time_sync_poll_seconds: 5\n reconcile_access_on_boot: true\n auto_etcd_restore_on_api_failure: true\n etcd_restore_control_plane: '"${default_restore_cp}"'' "${CONF_DIR}/hecate.yaml" echo "[install] added startup.auto_etcd_restore_on_api_failure + startup.etcd_restore_control_plane defaults" changed=1 fi + if grep -Eq '^ api_poll_seconds:[[:space:]]*[0-9]+' "${CONF_DIR}/hecate.yaml" \ + && ! grep -Eq '^ require_time_sync:[[:space:]]*(true|false)' "${CONF_DIR}/hecate.yaml"; then + sed -Ei '/^ api_poll_seconds:[[:space:]]*[0-9]+/a\ require_time_sync: true\n time_sync_wait_seconds: 240\n time_sync_poll_seconds: 5\n reconcile_access_on_boot: true' "${CONF_DIR}/hecate.yaml" + echo "[install] added startup time sync + access reconciliation defaults" + changed=1 + fi local role role="$(read_hecate_role)"
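# Annotation (not part of the patch): ordering note for the two migration
# blocks above. The first block (the auto_etcd migration) now also inserts
# the four time-sync/access keys, so on any config it touches the second
# block's require_time_sync guard is already satisfied and nothing is
# appended twice; the second block only fires on configs that gained the
# etcd keys from an earlier release but lack the new startup keys.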