// Package sshutil contains helpers for recognising SSH host-key verification
// failures and for pruning stale host entries from known_hosts files.
package sshutil

import (
	"context"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"syscall"
	"time"
)

// hostKeyErrorMarkers are lowercase substrings of ssh output that indicate a
// host-key verification problem.
var hostKeyErrorMarkers = []string{
	"remote host identification has changed",
	"host key verification failed",
	"offending ",
	"possible dns spoofing detected",
}

// IsHostKeyError reports whether a failed ssh invocation failed because of a
// host-key verification problem. It returns false when err is nil; otherwise
// it searches output together with err's message, case-insensitively, for any
// of hostKeyErrorMarkers.
func IsHostKeyError(output string, err error) bool {
	if err == nil {
		return false
	}
	combined := strings.ToLower(strings.TrimSpace(output + "\n" + err.Error()))
	if combined == "" {
		return false
	}
	for _, marker := range hostKeyErrorMarkers {
		if strings.Contains(combined, marker) {
			return true
		}
	}
	return false
}

// ShouldAttemptKnownHostsRepair reports whether a failed ssh invocation
// warrants pruning known_hosts. It is true for any IsHostKeyError match, and
// also when err mentions "exit status 255" while output is blank.
func ShouldAttemptKnownHostsRepair(output string, err error) bool {
	if IsHostKeyError(output, err) {
		return true
	}
	if err == nil {
		return false
	}
	// Some SSH invocations (especially under strict non-interactive configs)
	// return exit 255 without forwarding the host-key mismatch text.
	if strings.Contains(strings.ToLower(err.Error()), "exit status 255") && strings.TrimSpace(output) == "" {
		return true
	}
	return false
}

// KnownHostsFiles returns candidate known_hosts paths in a deterministic,
// first-seen order with duplicates removed:
//   - a fixed list of common locations;
//   - the current user's ~/.ssh/known_hosts (when the home dir is known);
//   - a known_hosts file next to sshConfigFile and next to sshIdentityFile,
//     when those are non-blank.
//
// The returned paths are not checked for existence.
func KnownHostsFiles(sshConfigFile, sshIdentityFile string) []string {
	candidates := []string{
		// Common locations for this environment.
		"/root/.ssh/known_hosts",
		"/home/atlas/.ssh/known_hosts",
		"/home/tethys/.ssh/known_hosts",
	}
	if home, err := os.UserHomeDir(); err == nil && strings.TrimSpace(home) != "" {
		candidates = append(candidates, filepath.Join(home, ".ssh", "known_hosts"))
	}
	if cfg := strings.TrimSpace(sshConfigFile); cfg != "" {
		candidates = append(candidates, filepath.Join(filepath.Dir(cfg), "known_hosts"))
	}
	if key := strings.TrimSpace(sshIdentityFile); key != "" {
		candidates = append(candidates, filepath.Join(filepath.Dir(key), "known_hosts"))
	}
	return dedupTrimmed(candidates)
}

// RepairKnownHosts removes every entry for hosts from each existing,
// non-directory file in knownHostsFiles by running `ssh-keygen -R`. When
// port > 0, the "[host]:port" form of each host is removed as well.
//
// The repair is best-effort: problems are logged through logger (which may be
// nil) rather than returned. Blank and duplicate hosts/files are ignored, and
// files are processed in the order given. If ctx is cancelled, the remaining
// removals are skipped.
func RepairKnownHosts(ctx context.Context, logger *log.Logger, knownHostsFiles []string, hosts []string, port int) {
	if _, err := exec.LookPath("ssh-keygen"); err != nil {
		logf(logger, "warning: cannot repair known_hosts (ssh-keygen missing): %v", err)
		return
	}
	uniqueHosts := dedupTrimmed(hosts)
	if len(uniqueHosts) == 0 {
		return
	}
	for _, file := range dedupTrimmed(knownHostsFiles) {
		if info, err := os.Stat(file); err != nil || info.IsDir() {
			continue
		}
		for _, host := range uniqueHosts {
			// Without this check every remaining ssh-keygen run would fail
			// immediately and log its own warning.
			if err := ctx.Err(); err != nil {
				logf(logger, "warning: known_hosts repair aborted: %v", err)
				return
			}
			removeKnownHostEntry(ctx, logger, file, host)
			if port > 0 {
				removeKnownHostEntry(ctx, logger, file, fmt.Sprintf("[%s]:%d", host, port))
			}
		}
	}
}

// removeKnownHostEntry runs `ssh-keygen -R entry -f file` with an 8s timeout
// derived from ctx, then restores file's original owner and permissions on
// file and file+".old". A successful removal is logged. An absent entry is
// silent. Any other failure is logged as a warning.
func removeKnownHostEntry(ctx context.Context, logger *log.Logger, file string, entry string) {
	uid, gid, mode := captureOwnership(file)

	runCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
	defer cancel()
	out, err := exec.CommandContext(runCtx, "ssh-keygen", "-R", entry, "-f", file).CombinedOutput()

	// ssh-keygen -R may replace file and leaves a ".old" backup; put the
	// original owner and mode back on both regardless of the outcome.
	restoreOwnership(file, file+".old", uid, gid, mode)

	output := strings.TrimSpace(string(out))
	// An absent entry is not a failure. OpenSSH's ssh-keygen reports it as
	// "Host ... not found in ..." and may still exit 0, so this must be
	// checked before a nil err is taken to mean something was removed.
	if strings.Contains(strings.ToLower(output), "not found in") {
		return
	}
	if err != nil {
		logf(logger, "warning: known_hosts cleanup failed for %s in %s: %v: %s", entry, file, err, output)
		return
	}
	logf(logger, "known_hosts repaired: removed %s from %s", entry, file)
}

// captureOwnership returns path's uid, gid and permission bits. It returns
// (-1, -1, 0) when path cannot be stat'ed. It returns (-1, -1, perm) when the
// platform stat data is not a *syscall.Stat_t.
func captureOwnership(path string) (int, int, os.FileMode) {
	info, err := os.Stat(path)
	if err != nil {
		return -1, -1, 0
	}
	st, ok := info.Sys().(*syscall.Stat_t)
	if !ok {
		return -1, -1, info.Mode().Perm()
	}
	return int(st.Uid), int(st.Gid), info.Mode().Perm()
}

// restoreOwnership applies uid/gid and, when non-zero, mode to path and
// backupPath, skipping any that are empty or do not exist. It does nothing
// when the ownership is unknown (uid or gid negative).
func restoreOwnership(path string, backupPath string, uid int, gid int, mode os.FileMode) {
	if uid < 0 || gid < 0 {
		return
	}
	for _, candidate := range []string{path, backupPath} {
		if candidate == "" {
			continue
		}
		if _, err := os.Stat(candidate); err != nil {
			continue
		}
		// Best-effort: Chown commonly fails without privileges, and a failure
		// here should not turn a successful known_hosts cleanup into an error.
		_ = os.Chown(candidate, uid, gid)
		if mode != 0 {
			_ = os.Chmod(candidate, mode)
		}
	}
}

// dedupTrimmed returns the whitespace-trimmed, non-empty values in first-seen
// order with duplicates removed. The result is never nil.
func dedupTrimmed(values []string) []string {
	seen := make(map[string]struct{}, len(values))
	out := make([]string, 0, len(values))
	for _, v := range values {
		v = strings.TrimSpace(v)
		if v == "" {
			continue
		}
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}

// logf forwards to logger.Printf; a nil logger discards the message.
func logf(logger *log.Logger, format string, args ...any) {
	if logger != nil {
		logger.Printf(format, args...)
	}
}