shutdown: parallelize drain and restore scaled workloads

Brad Stein 2026-04-04 15:15:34 -03:00
parent ac2fbf89cb
commit 4b0fffd5e2
5 changed files with 341 additions and 59 deletions

View File

@@ -46,6 +46,9 @@ shutdown:
default_budget_seconds: 300
skip_etcd_snapshot: false
skip_drain: false
drain_parallelism: 6
scale_parallelism: 8
ssh_parallelism: 8
poweroff_enabled: true
poweroff_delay_seconds: 25
poweroff_local_host: true

View File

@@ -63,6 +63,9 @@ shutdown:
default_budget_seconds: 300
skip_etcd_snapshot: false
skip_drain: false
drain_parallelism: 6
scale_parallelism: 8
ssh_parallelism: 8
poweroff_enabled: true
poweroff_delay_seconds: 25
poweroff_local_host: true

View File

@@ -80,6 +80,9 @@ shutdown:
default_budget_seconds: 300
skip_etcd_snapshot: false
skip_drain: false
drain_parallelism: 6
scale_parallelism: 8
ssh_parallelism: 8
poweroff_enabled: true
poweroff_delay_seconds: 25
poweroff_local_host: true
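
The same three knobs are added to all three bundled config profiles: drain_parallelism bounds concurrent node drains, scale_parallelism bounds concurrent kubectl scale calls, and ssh_parallelism bounds concurrent SSH sessions. As a minimal sketch of how such a block decodes into the typed config added later in this commit (the gopkg.in/yaml.v3 decoder is an assumption; the repo's loader may differ):

package main

import (
    "fmt"

    "gopkg.in/yaml.v3"
)

// Mirrors the three new fields on config.Shutdown (see the last file below).
type shutdownKnobs struct {
    DrainParallelism int `yaml:"drain_parallelism"`
    ScaleParallelism int `yaml:"scale_parallelism"`
    SSHParallelism   int `yaml:"ssh_parallelism"`
}

func main() {
    raw := []byte("drain_parallelism: 6\nscale_parallelism: 8\nssh_parallelism: 8\n")
    var k shutdownKnobs
    if err := yaml.Unmarshal(raw, &k); err != nil {
        panic(err)
    }
    fmt.Printf("%+v\n", k) // {DrainParallelism:6 ScaleParallelism:8 SSHParallelism:8}
}
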

View File

@@ -12,6 +12,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
"scm.bstein.dev/bstein/hecate/internal/config"
@@ -44,6 +45,18 @@ type startupWorkload struct {
Name string
}
type workloadScaleEntry struct {
Namespace string `json:"namespace"`
Kind string `json:"kind"`
Name string `json:"name"`
Replicas int `json:"replicas"`
}
type workloadScaleSnapshot struct {
GeneratedAt time.Time `json:"generated_at"`
Entries []workloadScaleEntry `json:"entries"`
}
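
These two types define the on-disk restore record. Marshalled with the same json tags, a one-entry snapshot looks like the output below (namespace and workload name are made-up examples):

package main

import (
    "encoding/json"
    "fmt"
    "time"
)

type workloadScaleEntry struct {
    Namespace string `json:"namespace"`
    Kind      string `json:"kind"`
    Name      string `json:"name"`
    Replicas  int    `json:"replicas"`
}

type workloadScaleSnapshot struct {
    GeneratedAt time.Time            `json:"generated_at"`
    Entries     []workloadScaleEntry `json:"entries"`
}

func main() {
    snap := workloadScaleSnapshot{
        GeneratedAt: time.Date(2026, 4, 4, 18, 0, 0, 0, time.UTC),
        Entries: []workloadScaleEntry{
            {Namespace: "media", Kind: "deployment", Name: "jellyfin", Replicas: 2},
        },
    }
    b, _ := json.MarshalIndent(snap, "", "  ")
    fmt.Println(string(b))
    // {
    //   "generated_at": "2026-04-04T18:00:00Z",
    //   "entries": [
    //     {
    //       "namespace": "media",
    //       "kind": "deployment",
    //       "name": "jellyfin",
    //       "replicas": 2
    //     }
    //   ]
    // }
}
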
var criticalStartupWorkloads = []startupWorkload{
{Namespace: "flux-system", Kind: "deployment", Name: "source-controller"},
{Namespace: "flux-system", Kind: "deployment", Name: "kustomize-controller"},
@@ -176,6 +189,8 @@ func (o *Orchestrator) Startup(ctx context.Context, opts StartupOptions) (err er
}
}
o.bestEffort("restore scaled workloads", func() error { return o.restoreScaledApps(ctx) })
if err := o.resumeFluxAndReconcile(ctx); err != nil {
return err
}
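
The restore runs under bestEffort, so a failed restore logs a warning and startup still proceeds to the Flux resume. bestEffort itself is not part of this diff; a minimal sketch of its assumed shape, i.e. log-and-continue:

// Assumed shape of the existing helper: run fn, log on failure, never abort.
func (o *Orchestrator) bestEffort(action string, fn func() error) {
    if err := fn(); err != nil {
        o.log.Printf("warning: %s failed: %v", action, err)
    }
}
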
@@ -330,42 +345,278 @@ func (o *Orchestrator) patchFluxSuspendAll(ctx context.Context, suspend bool) er
}
func (o *Orchestrator) scaleDownApps(ctx context.Context) error {
nsOut, err := o.kubectl(ctx, 15*time.Second, "get", "ns", "-o", "jsonpath={range .items[*]}{.metadata.name}{'\\n'}{end}")
targets, err := o.listScalableWorkloads(ctx)
if err != nil {
return err
}
exclude := map[string]struct{}{}
for _, ns := range o.cfg.ExcludedNamespaces {
exclude[ns] = struct{}{}
if err := o.writeScaledWorkloadSnapshot(targets); err != nil {
return err
}
for _, ns := range lines(nsOut) {
if _, ok := exclude[ns]; ok {
continue
}
if _, scaleErr := o.kubectl(ctx, 15*time.Second, "-n", ns, "scale", "deployment", "--all", "--replicas=0"); scaleErr != nil {
o.log.Printf("warning: scale deployments in %s failed: %v", ns, scaleErr)
}
if _, scaleErr := o.kubectl(ctx, 15*time.Second, "-n", ns, "scale", "statefulset", "--all", "--replicas=0"); scaleErr != nil {
o.log.Printf("warning: scale statefulsets in %s failed: %v", ns, scaleErr)
}
if len(targets) == 0 {
o.log.Printf("scale down apps: no workloads above 0 replicas")
return nil
}
parallelism := o.cfg.Shutdown.ScaleParallelism
if parallelism <= 0 {
parallelism = 8
}
o.log.Printf("scale down apps targets=%d parallelism=%d", len(targets), parallelism)
return o.scaleWorkloads(ctx, targets, 0, parallelism)
}
func (o *Orchestrator) restoreScaledApps(ctx context.Context) error {
snapshot, err := o.readScaledWorkloadSnapshot()
if err != nil {
return err
}
if snapshot == nil || len(snapshot.Entries) == 0 {
o.log.Printf("restore scaled workloads: no snapshot entries to restore")
return nil
}
parallelism := o.cfg.Shutdown.ScaleParallelism
if parallelism <= 0 {
parallelism = 8
}
o.log.Printf("restore scaled workloads entries=%d parallelism=%d snapshot_at=%s", len(snapshot.Entries), parallelism, snapshot.GeneratedAt.Format(time.RFC3339))
if err := o.scaleWorkloads(ctx, snapshot.Entries, -1, parallelism); err != nil {
return err
}
if o.runner.DryRun {
return nil
}
if err := os.Remove(o.scaledWorkloadSnapshotPath()); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove scaled workload snapshot: %w", err)
}
return nil
}
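
Both paths funnel into scaleWorkloads below with the forceReplicas convention: a non-negative value overrides every entry (shutdown passes 0), while -1 means "use each entry's recorded replica count". The two call sites side by side:

o.scaleWorkloads(ctx, targets, 0, parallelism)           // shutdown: force everything to 0
o.scaleWorkloads(ctx, snapshot.Entries, -1, parallelism) // restore: use snapshotted replicas
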
func (o *Orchestrator) drainWorkers(ctx context.Context, workers []string) error {
total := len(workers)
for idx, node := range workers {
o.log.Printf("drain worker %d/%d: %s", idx+1, total, node)
if _, err := o.kubectl(ctx, 20*time.Second, "cordon", node); err != nil {
o.log.Printf("warning: cordon %s failed: %v", node, err)
func (o *Orchestrator) listScalableWorkloads(ctx context.Context) ([]workloadScaleEntry, error) {
exclude := map[string]struct{}{}
for _, ns := range o.cfg.ExcludedNamespaces {
exclude[strings.TrimSpace(ns)] = struct{}{}
}
collect := func(kind string) ([]workloadScaleEntry, error) {
out, err := o.kubectl(
ctx,
25*time.Second,
"get",
kind,
"-A",
"-o",
"jsonpath={range .items[*]}{.metadata.namespace}{'\\t'}{.metadata.name}{'\\t'}{.spec.replicas}{'\\n'}{end}",
)
if err != nil {
return nil, err
}
if _, err := o.kubectl(ctx, 3*time.Minute, "drain", node, "--ignore-daemonsets", "--delete-emptydir-data", "--grace-period=30", "--timeout=180s"); err != nil {
o.log.Printf("warning: drain %s failed: %v", node, err)
var entries []workloadScaleEntry
for _, line := range lines(out) {
parts := strings.Split(line, "\t")
if len(parts) < 3 {
continue
}
ns := strings.TrimSpace(parts[0])
if _, skip := exclude[ns]; skip {
continue
}
replicas, convErr := strconv.Atoi(strings.TrimSpace(parts[2]))
if convErr != nil || replicas <= 0 {
continue
}
entries = append(entries, workloadScaleEntry{
Namespace: ns,
Kind: kind,
Name: strings.TrimSpace(parts[1]),
Replicas: replicas,
})
}
return entries, nil
}
deployments, err := collect("deployment")
if err != nil {
return nil, fmt.Errorf("collect deployments: %w", err)
}
statefulsets, err := collect("statefulset")
if err != nil {
return nil, fmt.Errorf("collect statefulsets: %w", err)
}
targets := append(deployments, statefulsets...)
sort.Slice(targets, func(i, j int) bool {
a, b := targets[i], targets[j]
if a.Namespace != b.Namespace {
return a.Namespace < b.Namespace
}
if a.Kind != b.Kind {
return a.Kind < b.Kind
}
return a.Name < b.Name
})
return targets, nil
}
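
The jsonpath template emits one tab-separated namespace, name, replicas row per object, which keeps parsing to a split-and-convert loop. A self-contained version of the replica filter (namespace exclusion omitted; the kubectl output here is fabricated):

package main

import (
    "fmt"
    "strconv"
    "strings"
)

func main() {
    // Fabricated `kubectl get deployment -A -o jsonpath=...` output.
    out := "flux-system\tsource-controller\t1\nmedia\tjellyfin\t2\nbatch\tpaused-job\t0\n"
    for _, line := range strings.Split(strings.TrimSpace(out), "\n") {
        parts := strings.Split(line, "\t")
        if len(parts) < 3 {
            continue
        }
        replicas, err := strconv.Atoi(strings.TrimSpace(parts[2]))
        if err != nil || replicas <= 0 {
            continue // skip workloads already at zero
        }
        fmt.Printf("%s/%s -> %d\n", parts[0], parts[1], replicas)
    }
    // Output:
    // flux-system/source-controller -> 1
    // media/jellyfin -> 2
}
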
func (o *Orchestrator) scaleWorkloads(ctx context.Context, entries []workloadScaleEntry, forceReplicas int, parallelism int) error {
if len(entries) == 0 {
return nil
}
if parallelism <= 0 {
parallelism = 1
}
if parallelism > len(entries) {
parallelism = len(entries)
}
sem := make(chan struct{}, parallelism)
var wg sync.WaitGroup
errCh := make(chan error, len(entries))
for _, entry := range entries {
entry := entry
wg.Add(1)
go func() {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
replicas := entry.Replicas
if forceReplicas >= 0 {
replicas = forceReplicas
}
if _, err := o.kubectl(
ctx,
20*time.Second,
"-n",
entry.Namespace,
"scale",
entry.Kind,
entry.Name,
fmt.Sprintf("--replicas=%d", replicas),
); err != nil {
if isNotFoundErr(err) {
o.log.Printf("warning: skip missing workload while scaling %s/%s/%s", entry.Namespace, entry.Kind, entry.Name)
return
}
errCh <- fmt.Errorf("scale %s/%s/%s -> %d: %w", entry.Namespace, entry.Kind, entry.Name, replicas, err)
}
}()
}
wg.Wait()
close(errCh)
errorCount := len(errCh)
if errorCount == 0 {
return nil
}
errs := make([]string, 0, len(errCh))
for err := range errCh {
errs = append(errs, err.Error())
if len(errs) >= 5 {
break
}
}
return fmt.Errorf("scaling had %d errors (first: %s)", errorCount, strings.Join(errs, " | "))
}
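
scaleWorkloads is the first of three places in this commit that use the same fan-out shape: a buffered channel as a counting semaphore, a sync.WaitGroup, and an error channel buffered to len(items) so no sender can block once the workers finish. The pattern in isolation:

package main

import (
    "fmt"
    "sync"
)

// runBounded runs one task per item with at most `limit` in flight,
// collecting errors without ever blocking a sender. (The real code also
// clamps limit to len(items) and falls back to a default when <= 0.)
func runBounded(items []string, limit int, task func(string) error) []error {
    sem := make(chan struct{}, limit)
    errCh := make(chan error, len(items)) // buffered: senders never block
    var wg sync.WaitGroup
    for _, it := range items {
        it := it
        wg.Add(1)
        go func() {
            defer wg.Done()
            sem <- struct{}{}        // acquire a slot
            defer func() { <-sem }() // release it
            if err := task(it); err != nil {
                errCh <- err
            }
        }()
    }
    wg.Wait()
    close(errCh)
    var errs []error
    for err := range errCh {
        errs = append(errs, err)
    }
    return errs
}

func main() {
    errs := runBounded([]string{"a", "b", "c"}, 2, func(s string) error {
        if s == "b" {
            return fmt.Errorf("task %s failed", s)
        }
        return nil
    })
    fmt.Println(len(errs), "error(s)") // 1 error(s)
}
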
func (o *Orchestrator) scaledWorkloadSnapshotPath() string {
return filepath.Join(o.cfg.State.Dir, "scaled-workloads.json")
}
func (o *Orchestrator) writeScaledWorkloadSnapshot(entries []workloadScaleEntry) error {
if o.runner.DryRun {
return nil
}
if err := os.MkdirAll(o.cfg.State.Dir, 0o755); err != nil {
return fmt.Errorf("ensure state dir: %w", err)
}
payload := workloadScaleSnapshot{
GeneratedAt: time.Now().UTC(),
Entries: entries,
}
b, err := json.MarshalIndent(payload, "", " ")
if err != nil {
return fmt.Errorf("marshal scaled workload snapshot: %w", err)
}
if err := os.WriteFile(o.scaledWorkloadSnapshotPath(), b, 0o644); err != nil {
return fmt.Errorf("write scaled workload snapshot: %w", err)
}
return nil
}
func (o *Orchestrator) readScaledWorkloadSnapshot() (*workloadScaleSnapshot, error) {
if o.runner.DryRun {
return nil, nil
}
b, err := os.ReadFile(o.scaledWorkloadSnapshotPath())
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("read scaled workload snapshot: %w", err)
}
var snapshot workloadScaleSnapshot
if err := json.Unmarshal(b, &snapshot); err != nil {
return nil, fmt.Errorf("decode scaled workload snapshot: %w", err)
}
return &snapshot, nil
}
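
Write and read are symmetric, and a missing snapshot is treated as "nothing to restore" rather than an error. A quick round-trip under the same conventions, with the state dir swapped for a temp dir:

package main

import (
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
)

func main() {
    dir, err := os.MkdirTemp("", "hecate-state")
    if err != nil {
        panic(err)
    }
    defer os.RemoveAll(dir)
    path := filepath.Join(dir, "scaled-workloads.json")

    payload := []byte(`{"generated_at":"2026-04-04T18:00:00Z","entries":[{"namespace":"media","kind":"deployment","name":"jellyfin","replicas":2}]}`)
    if err := os.WriteFile(path, payload, 0o644); err != nil {
        panic(err)
    }

    b, err := os.ReadFile(path)
    if os.IsNotExist(err) {
        fmt.Println("no snapshot: nothing to restore") // the restore no-op path
        return
    }
    var snap struct {
        Entries []json.RawMessage `json:"entries"`
    }
    if err := json.Unmarshal(b, &snap); err != nil {
        panic(err)
    }
    fmt.Println(len(snap.Entries), "entry to restore") // 1 entry to restore
}
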
func (o *Orchestrator) drainWorkers(ctx context.Context, workers []string) error {
total := len(workers)
if total == 0 {
return nil
}
parallelism := o.cfg.Shutdown.DrainParallelism
if parallelism <= 0 {
parallelism = 6
}
if parallelism > total {
parallelism = total
}
o.log.Printf("drain workers total=%d parallelism=%d", total, parallelism)
sem := make(chan struct{}, parallelism)
var wg sync.WaitGroup
errCh := make(chan error, total)
for idx, node := range workers {
idx := idx
node := node
wg.Add(1)
go func() {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
o.log.Printf("drain worker %d/%d: %s", idx+1, total, node)
if _, err := o.kubectl(ctx, 20*time.Second, "cordon", node); err != nil {
o.log.Printf("warning: cordon %s failed: %v", node, err)
}
if _, err := o.kubectl(ctx, 3*time.Minute, "drain", node, "--ignore-daemonsets", "--delete-emptydir-data", "--grace-period=30", "--timeout=180s"); err != nil {
errCh <- fmt.Errorf("drain %s failed: %w", node, err)
return
}
}()
}
wg.Wait()
close(errCh)
if len(errCh) == 0 {
return nil
}
count := len(errCh)
samples := []string{}
for err := range errCh {
samples = append(samples, err.Error())
if len(samples) >= 4 {
break
}
}
return fmt.Errorf("drain workers had %d errors (first: %s)", count, strings.Join(samples, " | "))
}
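
Each drain gets a 3-minute Go-side timeout that matches kubectl's own --timeout=180s, so both deadlines expire together. The o.kubectl wrapper is outside this diff; assuming it shells out, a single drain call presumably reduces to something like:

package main

import (
    "context"
    "fmt"
    "os/exec"
    "time"
)

// drainNode is a guess at what o.kubectl does for drain: shell out with a
// per-call deadline layered on top of the caller's context.
func drainNode(ctx context.Context, node string) error {
    ctx, cancel := context.WithTimeout(ctx, 3*time.Minute)
    defer cancel()
    cmd := exec.CommandContext(ctx, "kubectl", "drain", node,
        "--ignore-daemonsets", "--delete-emptydir-data",
        "--grace-period=30", "--timeout=180s")
    out, err := cmd.CombinedOutput()
    if err != nil {
        return fmt.Errorf("drain %s: %w: %s", node, err, out)
    }
    return nil
}

func main() {
    if err := drainNode(context.Background(), "worker-1"); err != nil {
        fmt.Println(err)
    }
}
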
func (o *Orchestrator) uncordonWorkers(ctx context.Context, workers []string) error {
for _, node := range workers {
if _, err := o.kubectl(ctx, 20*time.Second, "uncordon", node); err != nil {
@@ -376,55 +627,53 @@ func (o *Orchestrator) uncordonWorkers(ctx context.Context, workers []string) er
}
func (o *Orchestrator) stopWorkers(ctx context.Context, workers []string) {
for _, n := range workers {
if !o.sshManaged(n) {
o.log.Printf("skip stop k3s-agent on %s: node not in ssh_managed_nodes", n)
continue
}
o.bestEffort("stop k3s-agent on "+n, func() error {
_, err := o.ssh(ctx, n, "sudo systemctl stop k3s-agent || true")
return err
})
}
o.runSSHAcrossNodes(ctx, workers, "stop k3s-agent", "sudo systemctl stop k3s-agent || true")
}
func (o *Orchestrator) startWorkers(ctx context.Context, workers []string) {
for _, n := range workers {
if !o.sshManaged(n) {
o.log.Printf("skip start k3s-agent on %s: node not in ssh_managed_nodes", n)
continue
}
o.bestEffort("start k3s-agent on "+n, func() error {
_, err := o.ssh(ctx, n, "sudo systemctl start k3s-agent || true")
return err
})
}
o.runSSHAcrossNodes(ctx, workers, "start k3s-agent", "sudo systemctl start k3s-agent || true")
}
func (o *Orchestrator) stopControlPlanes(ctx context.Context, cps []string) {
for _, n := range cps {
if !o.sshManaged(n) {
o.log.Printf("skip stop k3s on %s: node not in ssh_managed_nodes", n)
continue
}
o.bestEffort("stop k3s on "+n, func() error {
_, err := o.ssh(ctx, n, "sudo systemctl stop k3s || true")
return err
})
}
o.runSSHAcrossNodes(ctx, cps, "stop k3s", "sudo systemctl stop k3s || true")
}
func (o *Orchestrator) startControlPlanes(ctx context.Context, cps []string) {
for _, n := range cps {
if !o.sshManaged(n) {
o.log.Printf("skip start k3s on %s: node not in ssh_managed_nodes", n)
o.runSSHAcrossNodes(ctx, cps, "start k3s", "sudo systemctl start k3s || true")
}
func (o *Orchestrator) runSSHAcrossNodes(ctx context.Context, nodes []string, action, command string) {
if len(nodes) == 0 {
return
}
parallelism := o.cfg.Shutdown.SSHParallelism
if parallelism <= 0 {
parallelism = 8
}
if parallelism > len(nodes) {
parallelism = len(nodes)
}
sem := make(chan struct{}, parallelism)
var wg sync.WaitGroup
for _, node := range nodes {
node := node
if !o.sshManaged(node) {
o.log.Printf("skip %s on %s: node not in ssh_managed_nodes", action, node)
continue
}
o.bestEffort("start k3s on "+n, func() error {
_, err := o.ssh(ctx, n, "sudo systemctl start k3s || true")
return err
})
wg.Add(1)
go func() {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
o.bestEffort(action+" on "+node, func() error {
_, err := o.ssh(ctx, node, command)
return err
})
}()
}
wg.Wait()
}
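
With runSSHAcrossNodes in place, the four stop/start helpers above collapse to one-liners, and unmanaged nodes are filtered before any goroutine is spawned, so they never occupy a semaphore slot. A stubbed demonstration with hypothetical node names and a fake runner:

package main

import (
    "fmt"
    "sync"
)

func main() {
    nodes := []string{"cp-1", "worker-1", "worker-2"}
    managed := map[string]bool{"cp-1": true, "worker-1": true} // worker-2 unmanaged
    sem := make(chan struct{}, 2)
    var wg sync.WaitGroup
    for _, node := range nodes {
        node := node
        if !managed[node] {
            fmt.Println("skip stop k3s on", node) // filtered before any goroutine starts
            continue
        }
        wg.Add(1)
        go func() {
            defer wg.Done()
            sem <- struct{}{}
            defer func() { <-sem }()
            fmt.Println("ssh", node, "sudo systemctl stop k3s || true")
        }()
    }
    wg.Wait()
}
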
func (o *Orchestrator) takeEtcdSnapshot(ctx context.Context, node string) error {

View File

@@ -41,6 +41,9 @@ type Shutdown struct {
DefaultBudgetSeconds int `yaml:"default_budget_seconds"`
SkipEtcdSnapshot bool `yaml:"skip_etcd_snapshot"`
SkipDrain bool `yaml:"skip_drain"`
DrainParallelism int `yaml:"drain_parallelism"`
ScaleParallelism int `yaml:"scale_parallelism"`
SSHParallelism int `yaml:"ssh_parallelism"`
PoweroffEnabled bool `yaml:"poweroff_enabled"`
PoweroffDelaySeconds int `yaml:"poweroff_delay_seconds"`
PoweroffLocalHost bool `yaml:"poweroff_local_host"`
@@ -117,6 +120,15 @@ func (c Config) Validate() error {
if c.Shutdown.DefaultBudgetSeconds <= 0 {
return fmt.Errorf("config.shutdown.default_budget_seconds must be > 0")
}
if c.Shutdown.DrainParallelism <= 0 {
return fmt.Errorf("config.shutdown.drain_parallelism must be > 0")
}
if c.Shutdown.ScaleParallelism <= 0 {
return fmt.Errorf("config.shutdown.scale_parallelism must be > 0")
}
if c.Shutdown.SSHParallelism <= 0 {
return fmt.Errorf("config.shutdown.ssh_parallelism must be > 0")
}
if c.Startup.APIWaitSeconds <= 0 {
return fmt.Errorf("config.startup.api_wait_seconds must be > 0")
}
@@ -193,6 +205,9 @@ func defaults() Config {
},
Shutdown: Shutdown{
DefaultBudgetSeconds: 1380,
DrainParallelism: 6,
ScaleParallelism: 8,
SSHParallelism: 8,
PoweroffEnabled: true,
PoweroffDelaySeconds: 25,
PoweroffLocalHost: true,
@@ -247,6 +262,15 @@ func (c *Config) applyDefaults() {
if c.Shutdown.DefaultBudgetSeconds <= 0 {
c.Shutdown.DefaultBudgetSeconds = 1380
}
if c.Shutdown.DrainParallelism <= 0 {
c.Shutdown.DrainParallelism = 6
}
if c.Shutdown.ScaleParallelism <= 0 {
c.Shutdown.ScaleParallelism = 8
}
if c.Shutdown.SSHParallelism <= 0 {
c.Shutdown.SSHParallelism = 8
}
if c.Shutdown.PoweroffDelaySeconds <= 0 {
c.Shutdown.PoweroffDelaySeconds = 25
}
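
Note that applyDefaults back-fills any non-positive parallelism, so assuming the loader applies defaults before Validate runs, the new Validate checks act as a safety net rather than a user-facing error path. The interplay in miniature:

package main

import "fmt"

type shutdownCfg struct{ DrainParallelism int }

func applyDefaults(c *shutdownCfg) {
    if c.DrainParallelism <= 0 {
        c.DrainParallelism = 6 // same default as the config change above
    }
}

func validate(c shutdownCfg) error {
    if c.DrainParallelism <= 0 {
        return fmt.Errorf("config.shutdown.drain_parallelism must be > 0")
    }
    return nil
}

func main() {
    c := shutdownCfg{DrainParallelism: -3} // e.g. a stray negative in YAML
    applyDefaults(&c)
    fmt.Println(c.DrainParallelism, validate(c) == nil) // 6 true
}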