// AcquireLock creates an exclusive lock file at path and returns a function
// that releases the lock by removing that file.
//
// The lock file records its owner as "pid=<pid> started=<RFC3339 time>". When
// a lock file already exists it is reclaimed only if staleLock reports that
// the recorded owner is gone; otherwise an error is returned and the existing
// file is left untouched.
//
// NOTE(review): reclaiming is check-then-remove and is not atomic. Two
// processes that both observe the same stale file could each remove and
// recreate it. Closing that window needs an OS-level lock (e.g. flock) and is
// out of scope here.
func AcquireLock(path string) (func(), error) {
	if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
		return nil, err
	}

	unlock, err := createLockFile(path)
	if err == nil {
		return unlock, nil
	}
	if !errors.Is(err, os.ErrExist) {
		return nil, fmt.Errorf("acquire lock %s: %w", path, err)
	}

	stale, staleErr := staleLock(path)
	if staleErr != nil {
		return nil, fmt.Errorf("acquire lock %s: existing lock check failed: %w", path, staleErr)
	}
	if !stale {
		return nil, fmt.Errorf("acquire lock %s: lock is held by active process", path)
	}
	// The previous owner may have released the lock between our failed create
	// and this point; a missing file is exactly the state we want, not an error.
	if rmErr := os.Remove(path); rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
		return nil, fmt.Errorf("acquire lock %s: remove stale lock: %w", path, rmErr)
	}
	unlock, err = createLockFile(path)
	if err != nil {
		return nil, fmt.Errorf("acquire lock %s: recreate after stale lock removal: %w", path, err)
	}
	return unlock, nil
}

// createLockFile atomically creates path (failing with os.ErrExist if it is
// already present), writes the owner record into it, and returns the release
// function.
func createLockFile(path string) (func(), error) {
	f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600)
	if err != nil {
		return nil, err
	}
	_, err = fmt.Fprintf(f, "pid=%d started=%s\n", os.Getpid(), time.Now().Format(time.RFC3339))
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// A lock file without a readable pid is treated as stale by other
		// processes, so do not leave a half-written one behind.
		_ = os.Remove(path)
		return nil, fmt.Errorf("write lock owner: %w", err)
	}
	return func() {
		// Best effort: there is nothing useful to do if removal fails.
		_ = os.Remove(path)
	}, nil
}

// staleLock reports whether the lock file at path can be reclaimed.
//
// It returns true when the file is missing, when it carries no valid positive
// pid, or when signalling that pid with signal 0 fails with ESRCH (no such
// process). Any other outcome of the signal probe, including a permission
// error, is treated as "owner still alive". A non-nil error is returned only
// when the file exists but cannot be read.
func staleLock(path string) (bool, error) {
	b, err := os.ReadFile(path)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return true, nil
		}
		return false, err
	}
	pid, ok := lockOwnerPID(string(b))
	if !ok {
		return true, nil
	}
	if err := syscall.Kill(pid, 0); errors.Is(err, syscall.ESRCH) {
		return true, nil
	}
	return false, nil
}

// lockOwnerPID extracts the pid from lock file content. The pid is the first
// whitespace-separated field of the form "pid=<n>", so both "pid=123" and
// "pid=123 started=..." are understood. ok is false when no such field exists
// or its value is not a positive integer.
func lockOwnerPID(content string) (pid int, ok bool) {
	for _, field := range strings.Fields(content) {
		if !strings.HasPrefix(field, "pid=") {
			continue
		}
		n, err := strconv.Atoi(strings.TrimPrefix(field, "pid="))
		if err != nil || n <= 0 {
			return 0, false
		}
		return n, true
	}
	return 0, false
}

// ---- tests from internal/state/store_test.go ----

// TestAcquireLockLifecycle checks that acquiring creates the lock file and
// that the returned function removes it.
func TestAcquireLockLifecycle(t *testing.T) {
	lockPath := filepath.Join(t.TempDir(), "hecate.lock")
	unlock, err := AcquireLock(lockPath)
	if err != nil {
		t.Fatalf("acquire lock: %v", err)
	}
	if _, err := os.Stat(lockPath); err != nil {
		t.Fatalf("expected lock file to exist: %v", err)
	}
	unlock()
	if _, err := os.Stat(lockPath); !os.IsNotExist(err) {
		t.Fatalf("expected lock file to be removed, got: %v", err)
	}
}

// TestAcquireLockReclaimsStaleLock checks that a lock file naming a pid that
// is not running is replaced by a fresh lock.
func TestAcquireLockReclaimsStaleLock(t *testing.T) {
	lockPath := filepath.Join(t.TempDir(), "hecate.lock")
	if err := os.WriteFile(lockPath, []byte("pid=999999\n"), 0o600); err != nil {
		t.Fatalf("write stale lock: %v", err)
	}

	unlock, err := AcquireLock(lockPath)
	if err != nil {
		t.Fatalf("acquire lock with stale predecessor: %v", err)
	}
	defer unlock()

	b, err := os.ReadFile(lockPath)
	if err != nil {
		t.Fatalf("read lock: %v", err)
	}
	if !strings.Contains(string(b), "pid=") {
		t.Fatalf("expected lock content to contain pid, got %q", string(b))
	}
}

// TestAcquireLockRejectsActiveLock checks that a hand-written lock file naming
// the current process is not reclaimed.
func TestAcquireLockRejectsActiveLock(t *testing.T) {
	lockPath := filepath.Join(t.TempDir(), "hecate.lock")
	active := "pid=" + strconv.Itoa(os.Getpid()) + "\n"
	if err := os.WriteFile(lockPath, []byte(active), 0o600); err != nil {
		t.Fatalf("write active lock: %v", err)
	}

	if _, err := AcquireLock(lockPath); err == nil {
		t.Fatalf("expected acquire lock to fail when active pid holds lock")
	}
}

// TestAcquireLockRejectsOwnLockFormat guards the format AcquireLock itself
// writes ("pid=<n> started=<time>"): a second acquire must fail while the
// first lock is held, and must leave the held lock file in place.
func TestAcquireLockRejectsOwnLockFormat(t *testing.T) {
	lockPath := filepath.Join(t.TempDir(), "hecate.lock")
	unlock, err := AcquireLock(lockPath)
	if err != nil {
		t.Fatalf("acquire lock: %v", err)
	}
	defer unlock()

	if _, err := AcquireLock(lockPath); err == nil {
		t.Fatalf("expected second acquire to fail while lock is held")
	}
	if _, err := os.Stat(lockPath); err != nil {
		t.Fatalf("expected held lock file to survive rejected acquire: %v", err)
	}
}