diff --git a/configs/hecate.example.yaml b/configs/hecate.example.yaml
index df9d904..e1eacea 100644
--- a/configs/hecate.example.yaml
+++ b/configs/hecate.example.yaml
@@ -46,6 +46,9 @@ shutdown:
   default_budget_seconds: 300
   skip_etcd_snapshot: false
   skip_drain: false
+  drain_parallelism: 6
+  scale_parallelism: 8
+  ssh_parallelism: 8
   poweroff_enabled: true
   poweroff_delay_seconds: 25
   poweroff_local_host: true
diff --git a/configs/hecate.tethys.yaml b/configs/hecate.tethys.yaml
index c1c8a74..ea195d1 100644
--- a/configs/hecate.tethys.yaml
+++ b/configs/hecate.tethys.yaml
@@ -63,6 +63,9 @@ shutdown:
   default_budget_seconds: 300
   skip_etcd_snapshot: false
   skip_drain: false
+  drain_parallelism: 6
+  scale_parallelism: 8
+  ssh_parallelism: 8
   poweroff_enabled: true
   poweroff_delay_seconds: 25
   poweroff_local_host: true
diff --git a/configs/hecate.titan-db.yaml b/configs/hecate.titan-db.yaml
index 5f472c4..8d7689e 100644
--- a/configs/hecate.titan-db.yaml
+++ b/configs/hecate.titan-db.yaml
@@ -80,6 +80,9 @@ shutdown:
   default_budget_seconds: 300
   skip_etcd_snapshot: false
   skip_drain: false
+  drain_parallelism: 6
+  scale_parallelism: 8
+  ssh_parallelism: 8
   poweroff_enabled: true
   poweroff_delay_seconds: 25
   poweroff_local_host: true
diff --git a/internal/cluster/orchestrator.go b/internal/cluster/orchestrator.go
index c47c02d..a946f53 100644
--- a/internal/cluster/orchestrator.go
+++ b/internal/cluster/orchestrator.go
@@ -12,6 +12,7 @@ import (
 	"sort"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"scm.bstein.dev/bstein/hecate/internal/config"
@@ -44,6 +45,18 @@ type startupWorkload struct {
 	Name      string
 }
 
+type workloadScaleEntry struct {
+	Namespace string `json:"namespace"`
+	Kind      string `json:"kind"`
+	Name      string `json:"name"`
+	Replicas  int    `json:"replicas"`
+}
+
+type workloadScaleSnapshot struct {
+	GeneratedAt time.Time            `json:"generated_at"`
+	Entries     []workloadScaleEntry `json:"entries"`
+}
+
 var criticalStartupWorkloads = []startupWorkload{
 	{Namespace: "flux-system", Kind: "deployment", Name: "source-controller"},
 	{Namespace: "flux-system", Kind: "deployment", Name: "kustomize-controller"},
@@ -176,6 +189,8 @@ func (o *Orchestrator) Startup(ctx context.Context, opts StartupOptions) (err er
 		}
 	}
 
+	o.bestEffort("restore scaled workloads", func() error { return o.restoreScaledApps(ctx) })
+
 	if err := o.resumeFluxAndReconcile(ctx); err != nil {
 		return err
 	}
@@ -330,42 +345,278 @@ func (o *Orchestrator) patchFluxSuspendAll(ctx context.Context, suspend bool) er
 }
 
 func (o *Orchestrator) scaleDownApps(ctx context.Context) error {
-	nsOut, err := o.kubectl(ctx, 15*time.Second, "get", "ns", "-o", "jsonpath={range .items[*]}{.metadata.name}{'\\n'}{end}")
+	targets, err := o.listScalableWorkloads(ctx)
 	if err != nil {
 		return err
 	}
-	exclude := map[string]struct{}{}
-	for _, ns := range o.cfg.ExcludedNamespaces {
-		exclude[ns] = struct{}{}
+	if err := o.writeScaledWorkloadSnapshot(targets); err != nil {
+		return err
 	}
-	for _, ns := range lines(nsOut) {
-		if _, ok := exclude[ns]; ok {
-			continue
-		}
-		if _, scaleErr := o.kubectl(ctx, 15*time.Second, "-n", ns, "scale", "deployment", "--all", "--replicas=0"); scaleErr != nil {
-			o.log.Printf("warning: scale deployments in %s failed: %v", ns, scaleErr)
-		}
-		if _, scaleErr := o.kubectl(ctx, 15*time.Second, "-n", ns, "scale", "statefulset", "--all", "--replicas=0"); scaleErr != nil {
-			o.log.Printf("warning: scale statefulsets in %s failed: %v", ns, scaleErr)
-		}
+	if len(targets) == 0 {
+		o.log.Printf("scale down apps: no workloads above 0 replicas")
+		return nil
+	}
+
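+	// Fall back to the default scale parallelism of 8 when the config value is unset or invalid.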
+	parallelism := o.cfg.Shutdown.ScaleParallelism
+	if parallelism <= 0 {
+		parallelism = 8
+	}
+	o.log.Printf("scale down apps targets=%d parallelism=%d", len(targets), parallelism)
+	return o.scaleWorkloads(ctx, targets, 0, parallelism)
+}
+
+func (o *Orchestrator) restoreScaledApps(ctx context.Context) error {
+	snapshot, err := o.readScaledWorkloadSnapshot()
+	if err != nil {
+		return err
+	}
+	if snapshot == nil || len(snapshot.Entries) == 0 {
+		o.log.Printf("restore scaled workloads: no snapshot entries to restore")
+		return nil
+	}
+
+	parallelism := o.cfg.Shutdown.ScaleParallelism
+	if parallelism <= 0 {
+		parallelism = 8
+	}
+	o.log.Printf("restore scaled workloads entries=%d parallelism=%d snapshot_at=%s", len(snapshot.Entries), parallelism, snapshot.GeneratedAt.Format(time.RFC3339))
+	if err := o.scaleWorkloads(ctx, snapshot.Entries, -1, parallelism); err != nil {
+		return err
+	}
+	if o.runner.DryRun {
+		return nil
+	}
+	if err := os.Remove(o.scaledWorkloadSnapshotPath()); err != nil && !os.IsNotExist(err) {
+		return fmt.Errorf("remove scaled workload snapshot: %w", err)
 	}
 	return nil
 }
 
-func (o *Orchestrator) drainWorkers(ctx context.Context, workers []string) error {
-	total := len(workers)
-	for idx, node := range workers {
-		o.log.Printf("drain worker %d/%d: %s", idx+1, total, node)
-		if _, err := o.kubectl(ctx, 20*time.Second, "cordon", node); err != nil {
-			o.log.Printf("warning: cordon %s failed: %v", node, err)
+func (o *Orchestrator) listScalableWorkloads(ctx context.Context) ([]workloadScaleEntry, error) {
+	exclude := map[string]struct{}{}
+	for _, ns := range o.cfg.ExcludedNamespaces {
+		exclude[strings.TrimSpace(ns)] = struct{}{}
+	}
+
+	collect := func(kind string) ([]workloadScaleEntry, error) {
+		out, err := o.kubectl(
+			ctx,
+			25*time.Second,
+			"get",
+			kind,
+			"-A",
+			"-o",
+			"jsonpath={range .items[*]}{.metadata.namespace}{'\\t'}{.metadata.name}{'\\t'}{.spec.replicas}{'\\n'}{end}",
+		)
+		if err != nil {
+			return nil, err
 		}
-		if _, err := o.kubectl(ctx, 3*time.Minute, "drain", node, "--ignore-daemonsets", "--delete-emptydir-data", "--grace-period=30", "--timeout=180s"); err != nil {
-			o.log.Printf("warning: drain %s failed: %v", node, err)
+		var entries []workloadScaleEntry
+		for _, line := range lines(out) {
+			parts := strings.Split(line, "\t")
+			if len(parts) < 3 {
+				continue
+			}
+			ns := strings.TrimSpace(parts[0])
+			if _, skip := exclude[ns]; skip {
+				continue
+			}
+			replicas, convErr := strconv.Atoi(strings.TrimSpace(parts[2]))
+			if convErr != nil || replicas <= 0 {
+				continue
+			}
+			entries = append(entries, workloadScaleEntry{
+				Namespace: ns,
+				Kind:      kind,
+				Name:      strings.TrimSpace(parts[1]),
+				Replicas:  replicas,
+			})
+		}
+		return entries, nil
+	}
+
+	deployments, err := collect("deployment")
+	if err != nil {
+		return nil, fmt.Errorf("collect deployments: %w", err)
+	}
+	statefulsets, err := collect("statefulset")
+	if err != nil {
+		return nil, fmt.Errorf("collect statefulsets: %w", err)
+	}
+	targets := append(deployments, statefulsets...)
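+	// Sort targets by namespace, kind, and name so snapshots and logs keep a stable, deterministic order.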
+	sort.Slice(targets, func(i, j int) bool {
+		a, b := targets[i], targets[j]
+		if a.Namespace != b.Namespace {
+			return a.Namespace < b.Namespace
+		}
+		if a.Kind != b.Kind {
+			return a.Kind < b.Kind
+		}
+		return a.Name < b.Name
+	})
+	return targets, nil
+}
+
+func (o *Orchestrator) scaleWorkloads(ctx context.Context, entries []workloadScaleEntry, forceReplicas int, parallelism int) error {
+	if len(entries) == 0 {
+		return nil
+	}
+	if parallelism <= 0 {
+		parallelism = 1
+	}
+	if parallelism > len(entries) {
+		parallelism = len(entries)
+	}
+
+	sem := make(chan struct{}, parallelism)
+	var wg sync.WaitGroup
+	errCh := make(chan error, len(entries))
+
+	for _, entry := range entries {
+		entry := entry
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			sem <- struct{}{}
+			defer func() { <-sem }()
+
+			replicas := entry.Replicas
+			if forceReplicas >= 0 {
+				replicas = forceReplicas
+			}
+			if _, err := o.kubectl(
+				ctx,
+				20*time.Second,
+				"-n",
+				entry.Namespace,
+				"scale",
+				entry.Kind,
+				entry.Name,
+				fmt.Sprintf("--replicas=%d", replicas),
+			); err != nil {
+				if isNotFoundErr(err) {
+					o.log.Printf("warning: skip missing workload while scaling %s/%s/%s", entry.Namespace, entry.Kind, entry.Name)
+					return
+				}
+				errCh <- fmt.Errorf("scale %s/%s/%s -> %d: %w", entry.Namespace, entry.Kind, entry.Name, replicas, err)
+			}
+		}()
+	}
+
+	wg.Wait()
+	close(errCh)
+
+	errorCount := len(errCh)
+	if errorCount == 0 {
+		return nil
+	}
+	errs := make([]string, 0, len(errCh))
+	for err := range errCh {
+		errs = append(errs, err.Error())
+		if len(errs) >= 5 {
+			break
 		}
 	}
+	return fmt.Errorf("scaling had %d errors (first: %s)", errorCount, strings.Join(errs, " | "))
+}
+
+func (o *Orchestrator) scaledWorkloadSnapshotPath() string {
+	return filepath.Join(o.cfg.State.Dir, "scaled-workloads.json")
+}
+
+func (o *Orchestrator) writeScaledWorkloadSnapshot(entries []workloadScaleEntry) error {
+	if o.runner.DryRun {
+		return nil
+	}
+	if err := os.MkdirAll(o.cfg.State.Dir, 0o755); err != nil {
+		return fmt.Errorf("ensure state dir: %w", err)
+	}
+	payload := workloadScaleSnapshot{
+		GeneratedAt: time.Now().UTC(),
+		Entries:     entries,
+	}
+	b, err := json.MarshalIndent(payload, "", " ")
+	if err != nil {
+		return fmt.Errorf("marshal scaled workload snapshot: %w", err)
+	}
+	if err := os.WriteFile(o.scaledWorkloadSnapshotPath(), b, 0o644); err != nil {
+		return fmt.Errorf("write scaled workload snapshot: %w", err)
+	}
 	return nil
 }
 
+func (o *Orchestrator) readScaledWorkloadSnapshot() (*workloadScaleSnapshot, error) {
+	if o.runner.DryRun {
+		return nil, nil
+	}
+	b, err := os.ReadFile(o.scaledWorkloadSnapshotPath())
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, nil
+		}
+		return nil, fmt.Errorf("read scaled workload snapshot: %w", err)
+	}
+	var snapshot workloadScaleSnapshot
+	if err := json.Unmarshal(b, &snapshot); err != nil {
+		return nil, fmt.Errorf("decode scaled workload snapshot: %w", err)
+	}
+	return &snapshot, nil
+}
+
+func (o *Orchestrator) drainWorkers(ctx context.Context, workers []string) error {
+	total := len(workers)
+	if total == 0 {
+		return nil
+	}
+	parallelism := o.cfg.Shutdown.DrainParallelism
+	if parallelism <= 0 {
+		parallelism = 6
+	}
+	if parallelism > total {
+		parallelism = total
+	}
+
+	o.log.Printf("drain workers total=%d parallelism=%d", total, parallelism)
+	sem := make(chan struct{}, parallelism)
+	var wg sync.WaitGroup
+	errCh := make(chan error, total)
+
+	for idx, node := range workers {
+		idx := idx
+		node := node
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
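+			// Acquire a semaphore slot so at most parallelism drains run at once.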
+			sem <- struct{}{}
+			defer func() { <-sem }()
+
+			o.log.Printf("drain worker %d/%d: %s", idx+1, total, node)
+			if _, err := o.kubectl(ctx, 20*time.Second, "cordon", node); err != nil {
+				o.log.Printf("warning: cordon %s failed: %v", node, err)
+			}
+			if _, err := o.kubectl(ctx, 3*time.Minute, "drain", node, "--ignore-daemonsets", "--delete-emptydir-data", "--grace-period=30", "--timeout=180s"); err != nil {
+				errCh <- fmt.Errorf("drain %s failed: %w", node, err)
+				return
+			}
+		}()
+	}
+
+	wg.Wait()
+	close(errCh)
+	if len(errCh) == 0 {
+		return nil
+	}
+	count := len(errCh)
+	samples := []string{}
+	for err := range errCh {
+		samples = append(samples, err.Error())
+		if len(samples) >= 4 {
+			break
+		}
+	}
+	return fmt.Errorf("drain workers had %d errors (first: %s)", count, strings.Join(samples, " | "))
+}
+
 func (o *Orchestrator) uncordonWorkers(ctx context.Context, workers []string) error {
 	for _, node := range workers {
 		if _, err := o.kubectl(ctx, 20*time.Second, "uncordon", node); err != nil {
@@ -376,55 +627,53 @@ func (o *Orchestrator) uncordonWorkers(ctx context.Context, workers []string) er
 }
 
 func (o *Orchestrator) stopWorkers(ctx context.Context, workers []string) {
-	for _, n := range workers {
-		if !o.sshManaged(n) {
-			o.log.Printf("skip stop k3s-agent on %s: node not in ssh_managed_nodes", n)
-			continue
-		}
-		o.bestEffort("stop k3s-agent on "+n, func() error {
-			_, err := o.ssh(ctx, n, "sudo systemctl stop k3s-agent || true")
-			return err
-		})
-	}
+	o.runSSHAcrossNodes(ctx, workers, "stop k3s-agent", "sudo systemctl stop k3s-agent || true")
 }
 
 func (o *Orchestrator) startWorkers(ctx context.Context, workers []string) {
-	for _, n := range workers {
-		if !o.sshManaged(n) {
-			o.log.Printf("skip start k3s-agent on %s: node not in ssh_managed_nodes", n)
-			continue
-		}
-		o.bestEffort("start k3s-agent on "+n, func() error {
-			_, err := o.ssh(ctx, n, "sudo systemctl start k3s-agent || true")
-			return err
-		})
-	}
+	o.runSSHAcrossNodes(ctx, workers, "start k3s-agent", "sudo systemctl start k3s-agent || true")
 }
 
 func (o *Orchestrator) stopControlPlanes(ctx context.Context, cps []string) {
-	for _, n := range cps {
-		if !o.sshManaged(n) {
-			o.log.Printf("skip stop k3s on %s: node not in ssh_managed_nodes", n)
-			continue
-		}
-		o.bestEffort("stop k3s on "+n, func() error {
-			_, err := o.ssh(ctx, n, "sudo systemctl stop k3s || true")
-			return err
-		})
-	}
+	o.runSSHAcrossNodes(ctx, cps, "stop k3s", "sudo systemctl stop k3s || true")
 }
 
 func (o *Orchestrator) startControlPlanes(ctx context.Context, cps []string) {
-	for _, n := range cps {
-		if !o.sshManaged(n) {
-			o.log.Printf("skip start k3s on %s: node not in ssh_managed_nodes", n)
+	o.runSSHAcrossNodes(ctx, cps, "start k3s", "sudo systemctl start k3s || true")
+}
+
+func (o *Orchestrator) runSSHAcrossNodes(ctx context.Context, nodes []string, action, command string) {
+	if len(nodes) == 0 {
+		return
+	}
+	parallelism := o.cfg.Shutdown.SSHParallelism
+	if parallelism <= 0 {
+		parallelism = 8
+	}
+	if parallelism > len(nodes) {
+		parallelism = len(nodes)
+	}
+
+	sem := make(chan struct{}, parallelism)
+	var wg sync.WaitGroup
+	for _, node := range nodes {
+		node := node
+		if !o.sshManaged(node) {
+			o.log.Printf("skip %s on %s: node not in ssh_managed_nodes", action, node)
 			continue
 		}
-		o.bestEffort("start k3s on "+n, func() error {
-			_, err := o.ssh(ctx, n, "sudo systemctl start k3s || true")
-			return err
-		})
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			sem <- struct{}{}
+			defer func() { <-sem }()
+			o.bestEffort(action+" on "+node, func() error {
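+				// Run the systemctl command on the node over SSH as a best-effort action.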
+				_, err := o.ssh(ctx, node, command)
+				return err
+			})
+		}()
 	}
+	wg.Wait()
 }
 
 func (o *Orchestrator) takeEtcdSnapshot(ctx context.Context, node string) error {
diff --git a/internal/config/config.go b/internal/config/config.go
index 6d7216d..4077f25 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -41,6 +41,9 @@ type Shutdown struct {
 	DefaultBudgetSeconds int  `yaml:"default_budget_seconds"`
 	SkipEtcdSnapshot     bool `yaml:"skip_etcd_snapshot"`
 	SkipDrain            bool `yaml:"skip_drain"`
+	DrainParallelism     int  `yaml:"drain_parallelism"`
+	ScaleParallelism     int  `yaml:"scale_parallelism"`
+	SSHParallelism       int  `yaml:"ssh_parallelism"`
 	PoweroffEnabled      bool `yaml:"poweroff_enabled"`
 	PoweroffDelaySeconds int  `yaml:"poweroff_delay_seconds"`
 	PoweroffLocalHost    bool `yaml:"poweroff_local_host"`
@@ -117,6 +120,15 @@ func (c Config) Validate() error {
 	if c.Shutdown.DefaultBudgetSeconds <= 0 {
 		return fmt.Errorf("config.shutdown.default_budget_seconds must be > 0")
 	}
+	if c.Shutdown.DrainParallelism <= 0 {
+		return fmt.Errorf("config.shutdown.drain_parallelism must be > 0")
+	}
+	if c.Shutdown.ScaleParallelism <= 0 {
+		return fmt.Errorf("config.shutdown.scale_parallelism must be > 0")
+	}
+	if c.Shutdown.SSHParallelism <= 0 {
+		return fmt.Errorf("config.shutdown.ssh_parallelism must be > 0")
+	}
 	if c.Startup.APIWaitSeconds <= 0 {
 		return fmt.Errorf("config.startup.api_wait_seconds must be > 0")
 	}
@@ -193,6 +205,9 @@ func defaults() Config {
 		},
 		Shutdown: Shutdown{
 			DefaultBudgetSeconds: 1380,
+			DrainParallelism:     6,
+			ScaleParallelism:     8,
+			SSHParallelism:       8,
 			PoweroffEnabled:      true,
 			PoweroffDelaySeconds: 25,
 			PoweroffLocalHost:    true,
@@ -247,6 +262,15 @@ func (c *Config) applyDefaults() {
 	if c.Shutdown.DefaultBudgetSeconds <= 0 {
 		c.Shutdown.DefaultBudgetSeconds = 1380
 	}
+	if c.Shutdown.DrainParallelism <= 0 {
+		c.Shutdown.DrainParallelism = 6
+	}
+	if c.Shutdown.ScaleParallelism <= 0 {
+		c.Shutdown.ScaleParallelism = 8
+	}
+	if c.Shutdown.SSHParallelism <= 0 {
+		c.Shutdown.SSHParallelism = 8
+	}
 	if c.Shutdown.PoweroffDelaySeconds <= 0 {
 		c.Shutdown.PoweroffDelaySeconds = 25
 	}