hecate: harden startup recovery and ssh/state self-heal
This commit is contained in:
parent
cff88e4944
commit
ba76e81ec2
@ -19,6 +19,7 @@ import (
|
||||
"scm.bstein.dev/bstein/hecate/internal/config"
|
||||
"scm.bstein.dev/bstein/hecate/internal/execx"
|
||||
"scm.bstein.dev/bstein/hecate/internal/service"
|
||||
"scm.bstein.dev/bstein/hecate/internal/sshutil"
|
||||
"scm.bstein.dev/bstein/hecate/internal/state"
|
||||
"scm.bstein.dev/bstein/hecate/internal/ups"
|
||||
)
|
||||
@ -442,18 +443,12 @@ func tryPeerBootstrapHandoff(ctx context.Context, cfg config.Config, logger *log
|
||||
attempt := 1
|
||||
for {
|
||||
cmdArgs := append(append([]string{}, args...), target, remote)
|
||||
cmd := exec.CommandContext(ctx, "ssh", cmdArgs...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
_, err := runSSHWithRecovery(ctx, logger, cfg, cmdArgs, []string{coordinator, host, cfg.SSHJumpHost})
|
||||
if err == nil {
|
||||
logger.Printf("peer bootstrap handoff succeeded on %s (attempt=%d)", coordinator, attempt)
|
||||
return true, nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(string(out))
|
||||
if trimmed == "" {
|
||||
logger.Printf("peer bootstrap handoff attempt %d failed for %s: %v", attempt, coordinator, err)
|
||||
} else {
|
||||
logger.Printf("peer bootstrap handoff attempt %d failed for %s: %v: %s", attempt, coordinator, err, trimmed)
|
||||
}
|
||||
logger.Printf("peer bootstrap handoff attempt %d failed for %s: %v", attempt, coordinator, err)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -487,18 +482,12 @@ func coordinatorAllowsPeerFallbackStartup(ctx context.Context, cfg config.Config
|
||||
}
|
||||
remoteCmd := "sudo -n sh -lc 'if systemctl is-active --quiet hecate-bootstrap.service; then echo __HECATE_BOOTSTRAP_ACTIVE__; else echo __HECATE_BOOTSTRAP_IDLE__; fi; if [ -s /var/lib/hecate/intent.json ]; then cat /var/lib/hecate/intent.json; else echo \"{}\"; fi'"
|
||||
args := append(buildSSHBaseArgs(cfg), target, remoteCmd)
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
out, err := runSSHWithRecovery(ctx, logger, cfg, args, []string{coordinator, host, cfg.SSHJumpHost})
|
||||
if err != nil {
|
||||
trimmed := strings.TrimSpace(string(out))
|
||||
if trimmed == "" {
|
||||
logger.Printf("warning: coordinator guard check unavailable on %s: %v; allowing peer fallback startup", coordinator, err)
|
||||
} else {
|
||||
logger.Printf("warning: coordinator guard check unavailable on %s: %v: %s; allowing peer fallback startup", coordinator, err, trimmed)
|
||||
}
|
||||
logger.Printf("warning: coordinator guard check unavailable on %s: %v; allowing peer fallback startup", coordinator, err)
|
||||
return true, "coordinator unreachable", nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(string(out))
|
||||
trimmed := strings.TrimSpace(out)
|
||||
if strings.Contains(trimmed, "__HECATE_BOOTSTRAP_ACTIVE__") {
|
||||
return false, "coordinator bootstrap service is active", nil
|
||||
}
|
||||
@ -546,6 +535,32 @@ func coordinatorAllowsPeerFallbackStartup(ctx context.Context, cfg config.Config
|
||||
}
|
||||
}
|
||||
|
||||
func runSSHWithRecovery(ctx context.Context, logger *log.Logger, cfg config.Config, args []string, repairHosts []string) (string, error) {
|
||||
try := func() (string, error) {
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
trimmed := strings.TrimSpace(string(out))
|
||||
if err != nil {
|
||||
if trimmed == "" {
|
||||
return "", fmt.Errorf("ssh failed: %w", err)
|
||||
}
|
||||
return trimmed, fmt.Errorf("ssh failed: %w: %s", err, trimmed)
|
||||
}
|
||||
return trimmed, nil
|
||||
}
|
||||
|
||||
out, err := try()
|
||||
if err == nil {
|
||||
return out, nil
|
||||
}
|
||||
if !sshutil.IsHostKeyError(out, err) {
|
||||
return out, err
|
||||
}
|
||||
|
||||
sshutil.RepairKnownHosts(ctx, logger, sshutil.KnownHostsFiles(resolveSSHConfigFile(cfg), resolveSSHIdentityFile(cfg)), repairHosts, cfg.SSHPort)
|
||||
return try()
|
||||
}
|
||||
|
||||
func buildSSHBaseArgs(cfg config.Config) []string {
|
||||
args := []string{
|
||||
"-o", "BatchMode=yes",
|
||||
|
||||
@ -42,6 +42,10 @@ excluded_namespaces:
|
||||
startup:
|
||||
api_wait_seconds: 1200
|
||||
api_poll_seconds: 2
|
||||
require_time_sync: true
|
||||
time_sync_wait_seconds: 240
|
||||
time_sync_poll_seconds: 5
|
||||
reconcile_access_on_boot: true
|
||||
auto_etcd_restore_on_api_failure: true
|
||||
etcd_restore_control_plane: titan-0a
|
||||
shutdown:
|
||||
|
||||
@ -108,6 +108,10 @@ excluded_namespaces:
|
||||
startup:
|
||||
api_wait_seconds: 1200
|
||||
api_poll_seconds: 2
|
||||
require_time_sync: true
|
||||
time_sync_wait_seconds: 240
|
||||
time_sync_poll_seconds: 5
|
||||
reconcile_access_on_boot: true
|
||||
auto_etcd_restore_on_api_failure: true
|
||||
etcd_restore_control_plane: titan-0a
|
||||
shutdown:
|
||||
|
||||
@ -108,6 +108,10 @@ excluded_namespaces:
|
||||
startup:
|
||||
api_wait_seconds: 1200
|
||||
api_poll_seconds: 2
|
||||
require_time_sync: true
|
||||
time_sync_wait_seconds: 240
|
||||
time_sync_poll_seconds: 5
|
||||
reconcile_access_on_boot: true
|
||||
auto_etcd_restore_on_api_failure: true
|
||||
etcd_restore_control_plane: titan-0a
|
||||
shutdown:
|
||||
|
||||
@ -7,9 +7,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
neturl "net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -18,6 +21,7 @@ import (
|
||||
|
||||
"scm.bstein.dev/bstein/hecate/internal/config"
|
||||
"scm.bstein.dev/bstein/hecate/internal/execx"
|
||||
"scm.bstein.dev/bstein/hecate/internal/sshutil"
|
||||
"scm.bstein.dev/bstein/hecate/internal/state"
|
||||
)
|
||||
|
||||
@ -63,6 +67,8 @@ type workloadScaleSnapshot struct {
|
||||
Entries []workloadScaleEntry `json:"entries"`
|
||||
}
|
||||
|
||||
var datastoreEndpointPattern = regexp.MustCompile(`--datastore-endpoint(?:=|\s+)(?:'([^']+)'|"([^"]+)"|([^\s\\]+))`)
|
||||
|
||||
var criticalStartupWorkloads = []startupWorkload{
|
||||
{Namespace: "flux-system", Kind: "deployment", Name: "source-controller"},
|
||||
{Namespace: "flux-system", Kind: "deployment", Name: "kustomize-controller"},
|
||||
@ -119,6 +125,17 @@ func (o *Orchestrator) Startup(ctx context.Context, opts StartupOptions) (err er
|
||||
}
|
||||
|
||||
o.log.Printf("startup control-planes=%s", strings.Join(o.cfg.ControlPlanes, ","))
|
||||
if o.cfg.Startup.RequireTimeSync {
|
||||
if err := o.waitForTimeSync(ctx, o.cfg.ControlPlanes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := o.preflightExternalDatastore(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if o.cfg.Startup.ReconcileAccessOnBoot {
|
||||
o.bestEffort("reconcile control-plane access", func() error { return o.reconcileNodeAccess(ctx, o.cfg.ControlPlanes) })
|
||||
}
|
||||
|
||||
o.reportFluxSource(ctx, opts.ForceFluxBranch)
|
||||
o.startControlPlanes(ctx, o.cfg.ControlPlanes)
|
||||
@ -156,6 +173,9 @@ func (o *Orchestrator) Startup(ctx context.Context, opts StartupOptions) (err er
|
||||
return err
|
||||
}
|
||||
o.log.Printf("startup workers=%s", strings.Join(workers, ","))
|
||||
if o.cfg.Startup.ReconcileAccessOnBoot {
|
||||
o.bestEffort("reconcile worker access", func() error { return o.reconcileNodeAccess(ctx, workers) })
|
||||
}
|
||||
o.startWorkers(ctx, workers)
|
||||
o.bestEffort("uncordon workers", func() error { return o.uncordonWorkers(ctx, workers) })
|
||||
|
||||
@ -873,6 +893,214 @@ func (o *Orchestrator) waitForAPI(ctx context.Context, attempts int, sleep time.
|
||||
return fmt.Errorf("kubernetes API did not become reachable within timeout")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) waitForTimeSync(ctx context.Context, nodes []string) error {
|
||||
if o.runner.DryRun {
|
||||
return nil
|
||||
}
|
||||
wait := time.Duration(o.cfg.Startup.TimeSyncWaitSeconds) * time.Second
|
||||
if wait <= 0 {
|
||||
wait = 240 * time.Second
|
||||
}
|
||||
poll := time.Duration(o.cfg.Startup.TimeSyncPollSeconds) * time.Second
|
||||
if poll <= 0 {
|
||||
poll = 5 * time.Second
|
||||
}
|
||||
deadline := time.Now().Add(wait)
|
||||
for {
|
||||
unsynced := []string{}
|
||||
localOut, localErr := o.run(ctx, 10*time.Second, "sh", "-lc", "timedatectl show -p NTPSynchronized --value 2>/dev/null || echo unknown")
|
||||
if localErr != nil || !isTimeSynced(localOut) {
|
||||
if localErr != nil {
|
||||
unsynced = append(unsynced, fmt.Sprintf("local(%v)", localErr))
|
||||
} else {
|
||||
unsynced = append(unsynced, fmt.Sprintf("local(%s)", strings.TrimSpace(localOut)))
|
||||
}
|
||||
}
|
||||
for _, node := range nodes {
|
||||
node = strings.TrimSpace(node)
|
||||
if node == "" {
|
||||
continue
|
||||
}
|
||||
if !o.sshManaged(node) {
|
||||
continue
|
||||
}
|
||||
out, err := o.ssh(ctx, node, "timedatectl show -p NTPSynchronized --value 2>/dev/null || echo unknown")
|
||||
if err != nil || !isTimeSynced(out) {
|
||||
if err != nil {
|
||||
unsynced = append(unsynced, fmt.Sprintf("%s(%v)", node, err))
|
||||
} else {
|
||||
unsynced = append(unsynced, fmt.Sprintf("%s(%s)", node, strings.TrimSpace(out)))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(unsynced) == 0 {
|
||||
return nil
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return fmt.Errorf("startup blocked: time sync not ready within %s (%s)", wait, strings.Join(unsynced, ", "))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(poll):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isTimeSynced(raw string) bool {
|
||||
v := strings.ToLower(strings.TrimSpace(raw))
|
||||
return v == "yes" || v == "true" || v == "1"
|
||||
}
|
||||
|
||||
func (o *Orchestrator) preflightExternalDatastore(ctx context.Context) error {
|
||||
if len(o.cfg.ControlPlanes) == 0 {
|
||||
return nil
|
||||
}
|
||||
controlPlane := strings.TrimSpace(o.cfg.ControlPlanes[0])
|
||||
if controlPlane == "" || !o.sshManaged(controlPlane) {
|
||||
return nil
|
||||
}
|
||||
unitOut, err := o.ssh(ctx, controlPlane, "sudo systemctl cat k3s")
|
||||
if err != nil {
|
||||
o.log.Printf("warning: external datastore preflight skipped: unable to inspect %s k3s unit: %v", controlPlane, err)
|
||||
return nil
|
||||
}
|
||||
datastoreEndpoint := parseDatastoreEndpoint(unitOut)
|
||||
if datastoreEndpoint == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := neturl.Parse(datastoreEndpoint)
|
||||
if err != nil || u.Host == "" {
|
||||
o.log.Printf("warning: external datastore preflight skipped: unable to parse datastore endpoint %q", datastoreEndpoint)
|
||||
return nil
|
||||
}
|
||||
host := strings.TrimSpace(u.Hostname())
|
||||
port := strings.TrimSpace(u.Port())
|
||||
if port == "" {
|
||||
port = "5432"
|
||||
}
|
||||
address := net.JoinHostPort(host, port)
|
||||
if o.tcpReachable(address, 3*time.Second) {
|
||||
return nil
|
||||
}
|
||||
o.log.Printf("warning: datastore endpoint %s is unreachable; attempting software recovery", address)
|
||||
if node := o.nodeNameForHost(host); node != "" && o.sshManaged(node) {
|
||||
o.bestEffort("restart datastore service on "+node, func() error {
|
||||
_, err := o.ssh(ctx, node, "sudo systemctl restart postgresql || sudo systemctl restart postgresql@16-main || sudo systemctl restart postgres")
|
||||
return err
|
||||
})
|
||||
}
|
||||
deadline := time.Now().Add(90 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if o.tcpReachable(address, 3*time.Second) {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(3 * time.Second):
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("startup blocked: external datastore endpoint %s remained unreachable after recovery attempt", address)
|
||||
}
|
||||
|
||||
func parseDatastoreEndpoint(unitText string) string {
|
||||
if match := datastoreEndpointPattern.FindStringSubmatch(unitText); len(match) == 4 {
|
||||
for _, candidate := range match[1:] {
|
||||
value := strings.TrimSpace(candidate)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, raw := range strings.Split(unitText, "\n") {
|
||||
line := strings.TrimSpace(raw)
|
||||
idx := strings.Index(line, "--datastore-endpoint")
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(line[idx+len("--datastore-endpoint"):])
|
||||
value = strings.TrimSpace(strings.TrimPrefix(value, "="))
|
||||
value = strings.TrimSuffix(strings.TrimSpace(value), "\\")
|
||||
value = strings.Trim(value, `"'`)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (o *Orchestrator) nodeNameForHost(host string) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
if _, ok := o.cfg.SSHNodeHosts[host]; ok {
|
||||
return host
|
||||
}
|
||||
for node, mapped := range o.cfg.SSHNodeHosts {
|
||||
if strings.TrimSpace(mapped) == host {
|
||||
return strings.TrimSpace(node)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (o *Orchestrator) tcpReachable(address string, timeout time.Duration) bool {
|
||||
conn, err := net.DialTimeout("tcp", address, timeout)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *Orchestrator) reconcileNodeAccess(ctx context.Context, nodes []string) error {
|
||||
if len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
parallelism := o.cfg.Shutdown.SSHParallelism
|
||||
if parallelism <= 0 {
|
||||
parallelism = 8
|
||||
}
|
||||
if parallelism > len(nodes) {
|
||||
parallelism = len(nodes)
|
||||
}
|
||||
sem := make(chan struct{}, parallelism)
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, len(nodes))
|
||||
cmd := `sudo sh -lc 'id atlas >/dev/null 2>&1 || useradd -m -s /bin/bash atlas || true; install -d -m 0755 /etc/sudoers.d; printf "%s\n" "atlas ALL=(ALL) NOPASSWD: /usr/bin/systemctl, /usr/sbin/poweroff, /sbin/poweroff, /usr/local/bin/hecate" > /etc/sudoers.d/90-hecate-atlas; chmod 0440 /etc/sudoers.d/90-hecate-atlas; if command -v visudo >/dev/null 2>&1; then visudo -cf /etc/sudoers.d/90-hecate-atlas >/dev/null; fi'`
|
||||
for _, node := range nodes {
|
||||
node := strings.TrimSpace(node)
|
||||
if node == "" || !o.sshManaged(node) {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
if _, err := o.ssh(ctx, node, cmd); err != nil {
|
||||
errCh <- fmt.Errorf("%s: %w", node, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
if len(errCh) == 0 {
|
||||
return nil
|
||||
}
|
||||
samples := []string{}
|
||||
for err := range errCh {
|
||||
samples = append(samples, err.Error())
|
||||
if len(samples) >= 4 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("access reconcile had %d errors (first: %s)", len(errCh), strings.Join(samples, " | "))
|
||||
}
|
||||
|
||||
func (o *Orchestrator) fluxSourceReady(ctx context.Context) (bool, error) {
|
||||
out, err := o.kubectl(ctx, 10*time.Second, "-n", "flux-system", "get", "gitrepository", "flux-system", "-o", "jsonpath={.status.conditions[?(@.type==\"Ready\")].status}")
|
||||
if err != nil {
|
||||
@ -1027,8 +1255,14 @@ func (o *Orchestrator) ssh(ctx context.Context, node string, command string) (st
|
||||
}
|
||||
attempts := make([][]string, 0, 2)
|
||||
attemptNames := make([]string, 0, 2)
|
||||
knownHostsFiles := sshutil.KnownHostsFiles(sshConfigFile, sshIdentity)
|
||||
repairHosts := []string{node, host}
|
||||
if o.cfg.SSHJumpHost != "" {
|
||||
jump := o.cfg.SSHJumpHost
|
||||
repairHosts = append(repairHosts, jump)
|
||||
if mapped, ok := o.cfg.SSHNodeHosts[jump]; ok && strings.TrimSpace(mapped) != "" {
|
||||
repairHosts = append(repairHosts, strings.TrimSpace(mapped))
|
||||
}
|
||||
if o.cfg.SSHJumpUser != "" {
|
||||
jump = o.cfg.SSHJumpUser + "@" + jump
|
||||
}
|
||||
@ -1055,6 +1289,16 @@ func (o *Orchestrator) ssh(ctx context.Context, node string, command string) (st
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
if sshutil.IsHostKeyError(out, err) {
|
||||
o.log.Printf("warning: ssh host-key mismatch detected for %s via %s path; repairing known_hosts and retrying once", node, attemptNames[i])
|
||||
sshutil.RepairKnownHosts(ctx, o.log, knownHostsFiles, repairHosts, o.cfg.SSHPort)
|
||||
retryOut, retryErr := o.run(ctx, 45*time.Second, "ssh", args...)
|
||||
if retryErr == nil {
|
||||
return retryOut, nil
|
||||
}
|
||||
out = retryOut
|
||||
err = retryErr
|
||||
}
|
||||
lastOut = out
|
||||
lastErr = err
|
||||
if i < len(attempts)-1 {
|
||||
|
||||
@ -35,6 +35,10 @@ type Config struct {
|
||||
type Startup struct {
|
||||
APIWaitSeconds int `yaml:"api_wait_seconds"`
|
||||
APIPollSeconds int `yaml:"api_poll_seconds"`
|
||||
RequireTimeSync bool `yaml:"require_time_sync"`
|
||||
TimeSyncWaitSeconds int `yaml:"time_sync_wait_seconds"`
|
||||
TimeSyncPollSeconds int `yaml:"time_sync_poll_seconds"`
|
||||
ReconcileAccessOnBoot bool `yaml:"reconcile_access_on_boot"`
|
||||
AutoEtcdRestoreOnAPIFailure bool `yaml:"auto_etcd_restore_on_api_failure"`
|
||||
EtcdRestoreControlPlane string `yaml:"etcd_restore_control_plane"`
|
||||
}
|
||||
@ -138,6 +142,12 @@ func (c Config) Validate() error {
|
||||
if c.Startup.APIPollSeconds <= 0 {
|
||||
return fmt.Errorf("config.startup.api_poll_seconds must be > 0")
|
||||
}
|
||||
if c.Startup.TimeSyncWaitSeconds <= 0 {
|
||||
return fmt.Errorf("config.startup.time_sync_wait_seconds must be > 0")
|
||||
}
|
||||
if c.Startup.TimeSyncPollSeconds <= 0 {
|
||||
return fmt.Errorf("config.startup.time_sync_poll_seconds must be > 0")
|
||||
}
|
||||
if c.Startup.EtcdRestoreControlPlane != "" {
|
||||
found := false
|
||||
for _, cp := range c.ControlPlanes {
|
||||
@ -220,6 +230,10 @@ func defaults() Config {
|
||||
Startup: Startup{
|
||||
APIWaitSeconds: 1200,
|
||||
APIPollSeconds: 2,
|
||||
RequireTimeSync: true,
|
||||
TimeSyncWaitSeconds: 240,
|
||||
TimeSyncPollSeconds: 5,
|
||||
ReconcileAccessOnBoot: true,
|
||||
AutoEtcdRestoreOnAPIFailure: true,
|
||||
EtcdRestoreControlPlane: "titan-0a",
|
||||
},
|
||||
@ -277,6 +291,12 @@ func (c *Config) applyDefaults() {
|
||||
if c.Startup.APIPollSeconds <= 0 {
|
||||
c.Startup.APIPollSeconds = 2
|
||||
}
|
||||
if c.Startup.TimeSyncWaitSeconds <= 0 {
|
||||
c.Startup.TimeSyncWaitSeconds = 240
|
||||
}
|
||||
if c.Startup.TimeSyncPollSeconds <= 0 {
|
||||
c.Startup.TimeSyncPollSeconds = 5
|
||||
}
|
||||
if c.Startup.EtcdRestoreControlPlane == "" && len(c.ControlPlanes) > 0 {
|
||||
c.Startup.EtcdRestoreControlPlane = c.ControlPlanes[0]
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"scm.bstein.dev/bstein/hecate/internal/cluster"
|
||||
"scm.bstein.dev/bstein/hecate/internal/config"
|
||||
"scm.bstein.dev/bstein/hecate/internal/metrics"
|
||||
"scm.bstein.dev/bstein/hecate/internal/sshutil"
|
||||
"scm.bstein.dev/bstein/hecate/internal/state"
|
||||
"scm.bstein.dev/bstein/hecate/internal/ups"
|
||||
)
|
||||
@ -240,14 +241,35 @@ func (d *Daemon) forwardShutdown(ctx context.Context, reason string) error {
|
||||
args = append(args, "-J", jump)
|
||||
}
|
||||
args = append(args, target, remoteCmd)
|
||||
cmd := exec.CommandContext(runCtx, "ssh", args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
|
||||
try := func() (string, error) {
|
||||
cmd := exec.CommandContext(runCtx, "ssh", args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
trimmed := strings.TrimSpace(string(out))
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("forward shutdown via ssh failed: %w", err)
|
||||
if err != nil {
|
||||
if trimmed == "" {
|
||||
return "", fmt.Errorf("forward shutdown via ssh failed: %w", err)
|
||||
}
|
||||
return trimmed, fmt.Errorf("forward shutdown via ssh failed: %w: %s", err, trimmed)
|
||||
}
|
||||
return fmt.Errorf("forward shutdown via ssh failed: %w: %s", err, trimmed)
|
||||
return trimmed, nil
|
||||
}
|
||||
|
||||
out, err := try()
|
||||
if err != nil && sshutil.IsHostKeyError(out, err) {
|
||||
repairHosts := []string{d.cfg.Coordination.ForwardShutdownHost, host}
|
||||
if d.cfg.SSHJumpHost != "" {
|
||||
repairHosts = append(repairHosts, d.cfg.SSHJumpHost)
|
||||
}
|
||||
sshutil.RepairKnownHosts(runCtx, d.log, sshutil.KnownHostsFiles(d.resolveSSHConfigFile(), d.resolveSSHIdentityFile()), repairHosts, d.cfg.SSHPort)
|
||||
if _, err2 := try(); err2 == nil {
|
||||
return nil
|
||||
} else {
|
||||
return err2
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
143
internal/sshutil/sshutil.go
Normal file
143
internal/sshutil/sshutil.go
Normal file
@ -0,0 +1,143 @@
|
||||
package sshutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var hostKeyErrorMarkers = []string{
|
||||
"remote host identification has changed",
|
||||
"host key verification failed",
|
||||
"offending ",
|
||||
"possible dns spoofing detected",
|
||||
}
|
||||
|
||||
func IsHostKeyError(output string, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
combined := strings.ToLower(strings.TrimSpace(output + "\n" + err.Error()))
|
||||
if combined == "" {
|
||||
return false
|
||||
}
|
||||
for _, marker := range hostKeyErrorMarkers {
|
||||
if strings.Contains(combined, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func KnownHostsFiles(sshConfigFile, sshIdentityFile string) []string {
|
||||
seen := map[string]struct{}{}
|
||||
add := func(path string) {
|
||||
p := strings.TrimSpace(path)
|
||||
if p == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
return
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
}
|
||||
|
||||
// Common locations for this environment.
|
||||
add("/root/.ssh/known_hosts")
|
||||
add("/home/atlas/.ssh/known_hosts")
|
||||
add("/home/tethys/.ssh/known_hosts")
|
||||
|
||||
if home, err := os.UserHomeDir(); err == nil && strings.TrimSpace(home) != "" {
|
||||
add(filepath.Join(home, ".ssh", "known_hosts"))
|
||||
}
|
||||
|
||||
if cfg := strings.TrimSpace(sshConfigFile); cfg != "" {
|
||||
add(filepath.Join(filepath.Dir(cfg), "known_hosts"))
|
||||
}
|
||||
if key := strings.TrimSpace(sshIdentityFile); key != "" {
|
||||
add(filepath.Join(filepath.Dir(key), "known_hosts"))
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(seen))
|
||||
for path := range seen {
|
||||
out = append(out, path)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func RepairKnownHosts(ctx context.Context, logger *log.Logger, knownHostsFiles []string, hosts []string, port int) {
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
logf(logger, "warning: cannot repair known_hosts (ssh-keygen missing): %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dedupHosts := make([]string, 0, len(hosts))
|
||||
hostSet := map[string]struct{}{}
|
||||
for _, h := range hosts {
|
||||
host := strings.TrimSpace(h)
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := hostSet[host]; ok {
|
||||
continue
|
||||
}
|
||||
hostSet[host] = struct{}{}
|
||||
dedupHosts = append(dedupHosts, host)
|
||||
}
|
||||
if len(dedupHosts) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fileSet := map[string]struct{}{}
|
||||
for _, f := range knownHostsFiles {
|
||||
file := strings.TrimSpace(f)
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := fileSet[file]; ok {
|
||||
continue
|
||||
}
|
||||
fileSet[file] = struct{}{}
|
||||
}
|
||||
|
||||
for file := range fileSet {
|
||||
if stat, err := os.Stat(file); err != nil || stat.IsDir() {
|
||||
continue
|
||||
}
|
||||
for _, host := range dedupHosts {
|
||||
removeKnownHostEntry(ctx, logger, file, host)
|
||||
if port > 0 {
|
||||
removeKnownHostEntry(ctx, logger, file, fmt.Sprintf("[%s]:%d", host, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func removeKnownHostEntry(ctx context.Context, logger *log.Logger, file string, entry string) {
|
||||
runCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(runCtx, "ssh-keygen", "-R", entry, "-f", file)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
logf(logger, "known_hosts repaired: removed %s from %s", entry, file)
|
||||
return
|
||||
}
|
||||
trimmed := strings.ToLower(strings.TrimSpace(string(out)))
|
||||
// ssh-keygen exits non-zero when entry is absent; this is fine.
|
||||
if strings.Contains(trimmed, "not found in") {
|
||||
return
|
||||
}
|
||||
logf(logger, "warning: known_hosts cleanup failed for %s in %s: %v: %s", entry, file, err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
func logf(logger *log.Logger, format string, args ...any) {
|
||||
if logger != nil {
|
||||
logger.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
43
internal/sshutil/sshutil_test.go
Normal file
43
internal/sshutil/sshutil_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package sshutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsHostKeyErrorDetectsMismatch(t *testing.T) {
|
||||
out := "WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!"
|
||||
if !IsHostKeyError(out, errors.New("ssh failed")) {
|
||||
t.Fatalf("expected host-key mismatch to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsHostKeyErrorIgnoresGenericFailures(t *testing.T) {
|
||||
out := "connection timed out"
|
||||
if IsHostKeyError(out, errors.New("ssh failed")) {
|
||||
t.Fatalf("did not expect host-key mismatch for generic timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKnownHostsFilesIncludesDerivedPaths(t *testing.T) {
|
||||
configFile := "/home/atlas/.ssh/config"
|
||||
identityFile := "/home/tethys/.ssh/id_ed25519"
|
||||
files := KnownHostsFiles(configFile, identityFile)
|
||||
set := map[string]struct{}{}
|
||||
for _, f := range files {
|
||||
set[f] = struct{}{}
|
||||
}
|
||||
|
||||
want := []string{
|
||||
"/home/atlas/.ssh/known_hosts",
|
||||
"/home/tethys/.ssh/known_hosts",
|
||||
filepath.Join(filepath.Dir(configFile), "known_hosts"),
|
||||
filepath.Join(filepath.Dir(identityFile), "known_hosts"),
|
||||
}
|
||||
for _, path := range want {
|
||||
if _, ok := set[path]; !ok {
|
||||
t.Fatalf("expected known_hosts candidate %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
22
internal/state/heal.go
Normal file
22
internal/state/heal.go
Normal file
@ -0,0 +1,22 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func quarantineCorruptFile(path string, payload []byte, replacement []byte, mode os.FileMode) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
backup := fmt.Sprintf("%s.corrupt-%s", path, time.Now().UTC().Format("20060102T150405Z"))
|
||||
if err := os.WriteFile(backup, payload, 0o600); err != nil {
|
||||
return fmt.Errorf("write backup %s: %w", backup, err)
|
||||
}
|
||||
if err := os.WriteFile(path, replacement, mode); err != nil {
|
||||
return fmt.Errorf("write replacement %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -35,7 +35,10 @@ func ReadIntent(path string) (Intent, error) {
|
||||
}
|
||||
var in Intent
|
||||
if err := json.Unmarshal(b, &in); err != nil {
|
||||
return Intent{}, err
|
||||
if healErr := quarantineCorruptFile(path, b, []byte("{}\n"), 0o640); healErr != nil {
|
||||
return Intent{}, fmt.Errorf("decode intent: %w (auto-heal failed: %v)", err, healErr)
|
||||
}
|
||||
return Intent{}, nil
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
@ -28,3 +29,33 @@ func TestMustWriteIntentRejectsUnknownState(t *testing.T) {
|
||||
t.Fatalf("expected invalid state error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadIntentAutoHealsCorruptJSON(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
p := filepath.Join(dir, "intent.json")
|
||||
if err := os.WriteFile(p, []byte("{broken"), 0o640); err != nil {
|
||||
t.Fatalf("write corrupt intent: %v", err)
|
||||
}
|
||||
|
||||
in, err := ReadIntent(p)
|
||||
if err != nil {
|
||||
t.Fatalf("read intent with auto-heal: %v", err)
|
||||
}
|
||||
if in.State != "" {
|
||||
t.Fatalf("expected empty state after heal, got %q", in.State)
|
||||
}
|
||||
raw, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
t.Fatalf("read healed intent file: %v", err)
|
||||
}
|
||||
if string(raw) != "{}\n" {
|
||||
t.Fatalf("expected healed intent payload '{}', got %q", string(raw))
|
||||
}
|
||||
matches, err := filepath.Glob(filepath.Join(dir, "intent.json.corrupt-*"))
|
||||
if err != nil {
|
||||
t.Fatalf("glob backup files: %v", err)
|
||||
}
|
||||
if len(matches) != 1 {
|
||||
t.Fatalf("expected 1 backup file, got %d (%v)", len(matches), matches)
|
||||
}
|
||||
}
|
||||
|
||||
@ -158,7 +158,10 @@ func (s *Store) loadUnlocked() ([]RunRecord, error) {
|
||||
}
|
||||
var records []RunRecord
|
||||
if err := json.Unmarshal(b, &records); err != nil {
|
||||
return nil, err
|
||||
if healErr := quarantineCorruptFile(s.path, b, []byte("[]\n"), 0o640); healErr != nil {
|
||||
return nil, fmt.Errorf("decode run history: %w (auto-heal failed: %v)", err, healErr)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
@ -55,3 +55,33 @@ func TestAcquireLockRejectsActiveLock(t *testing.T) {
|
||||
t.Fatalf("expected acquire lock to fail when active pid holds lock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreLoadAutoHealsCorruptJSON(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
p := filepath.Join(dir, "runs.json")
|
||||
if err := os.WriteFile(p, []byte(`{"bad":`), 0o640); err != nil {
|
||||
t.Fatalf("write corrupt run history: %v", err)
|
||||
}
|
||||
|
||||
records, err := New(p).Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load with auto-heal: %v", err)
|
||||
}
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("expected no records after heal, got %d", len(records))
|
||||
}
|
||||
raw, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
t.Fatalf("read healed runs file: %v", err)
|
||||
}
|
||||
if string(raw) != "[]\n" {
|
||||
t.Fatalf("expected healed runs payload '[]', got %q", string(raw))
|
||||
}
|
||||
matches, err := filepath.Glob(filepath.Join(dir, "runs.json.corrupt-*"))
|
||||
if err != nil {
|
||||
t.Fatalf("glob backup files: %v", err)
|
||||
}
|
||||
if len(matches) != 1 {
|
||||
t.Fatalf("expected 1 backup file, got %d (%v)", len(matches), matches)
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,10 +199,16 @@ migrate_hecate_config() {
|
||||
fi
|
||||
if grep -Eq '^ api_poll_seconds:[[:space:]]*[0-9]+' "${CONF_DIR}/hecate.yaml" \
|
||||
&& ! grep -Eq '^ auto_etcd_restore_on_api_failure:[[:space:]]*(true|false)' "${CONF_DIR}/hecate.yaml"; then
|
||||
sed -Ei '/^ api_poll_seconds:[[:space:]]*[0-9]+/a\ auto_etcd_restore_on_api_failure: true\n etcd_restore_control_plane: '"${default_restore_cp}"'' "${CONF_DIR}/hecate.yaml"
|
||||
sed -Ei '/^ api_poll_seconds:[[:space:]]*[0-9]+/a\ require_time_sync: true\n time_sync_wait_seconds: 240\n time_sync_poll_seconds: 5\n reconcile_access_on_boot: true\n auto_etcd_restore_on_api_failure: true\n etcd_restore_control_plane: '"${default_restore_cp}"'' "${CONF_DIR}/hecate.yaml"
|
||||
echo "[install] added startup.auto_etcd_restore_on_api_failure + startup.etcd_restore_control_plane defaults"
|
||||
changed=1
|
||||
fi
|
||||
if grep -Eq '^ api_poll_seconds:[[:space:]]*[0-9]+' "${CONF_DIR}/hecate.yaml" \
|
||||
&& ! grep -Eq '^ require_time_sync:[[:space:]]*(true|false)' "${CONF_DIR}/hecate.yaml"; then
|
||||
sed -Ei '/^ api_poll_seconds:[[:space:]]*[0-9]+/a\ require_time_sync: true\n time_sync_wait_seconds: 240\n time_sync_poll_seconds: 5\n reconcile_access_on_boot: true' "${CONF_DIR}/hecate.yaml"
|
||||
echo "[install] added startup time sync + access reconciliation defaults"
|
||||
changed=1
|
||||
fi
|
||||
|
||||
local role
|
||||
role="$(read_hecate_role)"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user