package sshutil import ( "context" "fmt" "log" "os" "os/exec" "path/filepath" "strings" "time" ) var hostKeyErrorMarkers = []string{ "remote host identification has changed", "host key verification failed", "offending ", "possible dns spoofing detected", } func IsHostKeyError(output string, err error) bool { if err == nil { return false } combined := strings.ToLower(strings.TrimSpace(output + "\n" + err.Error())) if combined == "" { return false } for _, marker := range hostKeyErrorMarkers { if strings.Contains(combined, marker) { return true } } return false } func KnownHostsFiles(sshConfigFile, sshIdentityFile string) []string { seen := map[string]struct{}{} add := func(path string) { p := strings.TrimSpace(path) if p == "" { return } if _, ok := seen[p]; ok { return } seen[p] = struct{}{} } // Common locations for this environment. add("/root/.ssh/known_hosts") add("/home/atlas/.ssh/known_hosts") add("/home/tethys/.ssh/known_hosts") if home, err := os.UserHomeDir(); err == nil && strings.TrimSpace(home) != "" { add(filepath.Join(home, ".ssh", "known_hosts")) } if cfg := strings.TrimSpace(sshConfigFile); cfg != "" { add(filepath.Join(filepath.Dir(cfg), "known_hosts")) } if key := strings.TrimSpace(sshIdentityFile); key != "" { add(filepath.Join(filepath.Dir(key), "known_hosts")) } out := make([]string, 0, len(seen)) for path := range seen { out = append(out, path) } return out } func RepairKnownHosts(ctx context.Context, logger *log.Logger, knownHostsFiles []string, hosts []string, port int) { if _, err := exec.LookPath("ssh-keygen"); err != nil { logf(logger, "warning: cannot repair known_hosts (ssh-keygen missing): %v", err) return } dedupHosts := make([]string, 0, len(hosts)) hostSet := map[string]struct{}{} for _, h := range hosts { host := strings.TrimSpace(h) if host == "" { continue } if _, ok := hostSet[host]; ok { continue } hostSet[host] = struct{}{} dedupHosts = append(dedupHosts, host) } if len(dedupHosts) == 0 { return } fileSet := map[string]struct{}{} for _, f := range knownHostsFiles { file := strings.TrimSpace(f) if file == "" { continue } if _, ok := fileSet[file]; ok { continue } fileSet[file] = struct{}{} } for file := range fileSet { if stat, err := os.Stat(file); err != nil || stat.IsDir() { continue } for _, host := range dedupHosts { removeKnownHostEntry(ctx, logger, file, host) if port > 0 { removeKnownHostEntry(ctx, logger, file, fmt.Sprintf("[%s]:%d", host, port)) } } } } func removeKnownHostEntry(ctx context.Context, logger *log.Logger, file string, entry string) { runCtx, cancel := context.WithTimeout(ctx, 8*time.Second) defer cancel() cmd := exec.CommandContext(runCtx, "ssh-keygen", "-R", entry, "-f", file) out, err := cmd.CombinedOutput() if err == nil { logf(logger, "known_hosts repaired: removed %s from %s", entry, file) return } trimmed := strings.ToLower(strings.TrimSpace(string(out))) // ssh-keygen exits non-zero when entry is absent; this is fine. if strings.Contains(trimmed, "not found in") { return } logf(logger, "warning: known_hosts cleanup failed for %s in %s: %v: %s", entry, file, err, strings.TrimSpace(string(out))) } func logf(logger *log.Logger, format string, args ...any) { if logger != nil { logger.Printf(format, args...) } }