diff --git a/pkg/image/download.go b/pkg/image/download.go index 22bc443..ae3aa92 100644 --- a/pkg/image/download.go +++ b/pkg/image/download.go @@ -27,38 +27,37 @@ func DownloadAndVerify(url, dest, checksum string) (string, error) { } if strings.HasSuffix(url, ".xz") { archive := dest + ".xz" - if _, err := os.Stat(archive); errors.Is(err, os.ErrNotExist) { - if err := downloadRaw(url, archive); err != nil { - return "", err - } - } else if err != nil { + if err := ensureVerifiedFile(url, archive, checksum); err != nil { return "", err } - if err := VerifyChecksum(archive, checksum); err != nil { - return "", err - } - if _, err := os.Stat(dest); errors.Is(err, os.ErrNotExist) { - if err := decompressXZ(archive, dest); err != nil { - return "", err - } - } else if err != nil { + if err := decompressXZ(archive, dest); err != nil { return "", err } return dest, nil } - if _, err := os.Stat(dest); errors.Is(err, os.ErrNotExist) { - if err := downloadRaw(url, dest); err != nil { - return "", err - } - } else if err != nil { - return "", err - } - if err := VerifyChecksum(dest, checksum); err != nil { + if err := ensureVerifiedFile(url, dest, checksum); err != nil { return "", err } return dest, nil } +func ensureVerifiedFile(url, dest, checksum string) error { + if _, err := os.Stat(dest); err == nil { + if err := VerifyChecksum(dest, checksum); err == nil || checksum == "" { + return nil + } + if removeErr := os.Remove(dest); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) { + return removeErr + } + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + if err := downloadRaw(url, dest); err != nil { + return err + } + return VerifyChecksum(dest, checksum) +} + func downloadRaw(url, dest string) error { if strings.HasPrefix(url, "file://") { src := strings.TrimPrefix(url, "file://") diff --git a/pkg/image/download_test.go b/pkg/image/download_test.go index 16b825f..107a386 100644 --- a/pkg/image/download_test.go +++ b/pkg/image/download_test.go @@ -71,3 +71,39 @@ func TestDownloadAndVerifyUsesArchiveChecksumForXZ(t *testing.T) { t.Fatalf("unexpected decompressed content: %q", string(data)) } } + +func TestDownloadAndVerifyReplacesStaleBadArchiveCache(t *testing.T) { + if _, err := exec.LookPath("xz"); err != nil { + t.Skip("xz not available") + } + dir := t.TempDir() + raw := filepath.Join(dir, "base.img") + if err := os.WriteFile(raw, []byte("metis-xz-test"), 0o644); err != nil { + t.Fatal(err) + } + compressed := raw + ".xz" + cmd := exec.Command("xz", "-zk", raw) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("xz: %v: %s", err, string(out)) + } + archiveBytes, err := os.ReadFile(compressed) + if err != nil { + t.Fatalf("ReadFile archive: %v", err) + } + archiveSum := sha256.Sum256(archiveBytes) + dest := filepath.Join(dir, "copy.img") + staleArchive := dest + ".xz" + if err := os.WriteFile(staleArchive, []byte("bad-cache"), 0o644); err != nil { + t.Fatalf("WriteFile stale archive: %v", err) + } + if _, err := DownloadAndVerify("file://"+compressed, dest, "sha256:"+hex.EncodeToString(archiveSum[:])); err != nil { + t.Fatalf("DownloadAndVerify with stale archive: %v", err) + } + data, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("ReadFile dest: %v", err) + } + if string(data) != "metis-xz-test" { + t.Fatalf("unexpected decompressed content: %q", string(data)) + } +}