diff --git a/pkg/image/download.go b/pkg/image/download.go index e3eefb4..22bc443 100644 --- a/pkg/image/download.go +++ b/pkg/image/download.go @@ -15,21 +15,48 @@ import ( // Download fetches url into dest if dest does not exist. func Download(url, dest string) error { - if _, err := os.Stat(dest); err == nil { - return nil - } + _, err := DownloadAndVerify(url, dest, "") + return err +} + +// DownloadAndVerify fetches the source image, verifies it when a checksum is provided, +// and returns the local raw image path ready for copying or injection. +func DownloadAndVerify(url, dest, checksum string) (string, error) { if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { - return err + return "", err } if strings.HasSuffix(url, ".xz") { - tmp := dest + ".download.xz" - if err := downloadRaw(url, tmp); err != nil { - return err + archive := dest + ".xz" + if _, err := os.Stat(archive); errors.Is(err, os.ErrNotExist) { + if err := downloadRaw(url, archive); err != nil { + return "", err + } + } else if err != nil { + return "", err } - defer os.Remove(tmp) - return decompressXZ(tmp, dest) + if err := VerifyChecksum(archive, checksum); err != nil { + return "", err + } + if _, err := os.Stat(dest); errors.Is(err, os.ErrNotExist) { + if err := decompressXZ(archive, dest); err != nil { + return "", err + } + } else if err != nil { + return "", err + } + return dest, nil } - return downloadRaw(url, dest) + if _, err := os.Stat(dest); errors.Is(err, os.ErrNotExist) { + if err := downloadRaw(url, dest); err != nil { + return "", err + } + } else if err != nil { + return "", err + } + if err := VerifyChecksum(dest, checksum); err != nil { + return "", err + } + return dest, nil } func downloadRaw(url, dest string) error { diff --git a/pkg/image/download_test.go b/pkg/image/download_test.go index 464f391..16b825f 100644 --- a/pkg/image/download_test.go +++ b/pkg/image/download_test.go @@ -27,8 +27,47 @@ func TestDownloadDecompressesXZFileURLs(t *testing.T) { if err := Download("file://"+compressed, dest); err != nil { t.Fatalf("Download: %v", err) } - sum := sha256.Sum256([]byte("metis-xz-test")) - if err := VerifyChecksum(dest, "sha256:"+hex.EncodeToString(sum[:])); err != nil { - t.Fatalf("VerifyChecksum: %v", err) + data, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(data) != "metis-xz-test" { + t.Fatalf("unexpected decompressed content: %q", string(data)) + } +} + +func TestDownloadAndVerifyUsesArchiveChecksumForXZ(t *testing.T) { + if _, err := exec.LookPath("xz"); err != nil { + t.Skip("xz not available") + } + dir := t.TempDir() + raw := filepath.Join(dir, "base.img") + if err := os.WriteFile(raw, []byte("metis-xz-test"), 0o644); err != nil { + t.Fatal(err) + } + compressed := raw + ".xz" + cmd := exec.Command("xz", "-zk", raw) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("xz: %v: %s", err, string(out)) + } + archiveBytes, err := os.ReadFile(compressed) + if err != nil { + t.Fatalf("ReadFile archive: %v", err) + } + archiveSum := sha256.Sum256(archiveBytes) + dest := filepath.Join(dir, "copy.img") + localPath, err := DownloadAndVerify("file://"+compressed, dest, "sha256:"+hex.EncodeToString(archiveSum[:])) + if err != nil { + t.Fatalf("DownloadAndVerify: %v", err) + } + if localPath != dest { + t.Fatalf("expected local path %s, got %s", dest, localPath) + } + data, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("ReadFile dest: %v", err) + } + if string(data) != "metis-xz-test" { + t.Fatalf("unexpected decompressed content: %q", string(data)) } } diff --git a/pkg/plan/burn.go b/pkg/plan/burn.go index 54fff50..23b0526 100644 --- a/pkg/plan/burn.go +++ b/pkg/plan/burn.go @@ -18,13 +18,11 @@ func Execute(inv *inventory.Inventory, nodeName, device, cacheDir string, confir if err != nil { return nil, err } - cacheImage := filepath.Join(cacheDir, filepath.Base(p.Image)) - if err := image.Download(p.Image, cacheImage); err != nil { + cacheImage := filepath.Join(cacheDir, cacheName(p.Image)) + cacheImage, err = image.DownloadAndVerify(p.Image, cacheImage, checksumFromInventory(inv, nodeName)) + if err != nil { return p, fmt.Errorf("download image: %w", err) } - if err := image.VerifyChecksum(cacheImage, checksumFromInventory(inv, nodeName)); err != nil { - return p, err - } if !confirm { return p, nil } diff --git a/pkg/plan/image_build.go b/pkg/plan/image_build.go index 6baf6e8..d06c7fa 100644 --- a/pkg/plan/image_build.go +++ b/pkg/plan/image_build.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "path/filepath" + "strings" "metis/pkg/image" "metis/pkg/inventory" @@ -21,13 +22,11 @@ func BuildImageFile(ctx context.Context, inv *inventory.Inventory, nodeName, cac return fmt.Errorf("load node class: %w", err) } - cacheImage := filepath.Join(cacheDir, filepath.Base(p.Image)) - if err := image.Download(p.Image, cacheImage); err != nil { + cacheImage := filepath.Join(cacheDir, cacheName(p.Image)) + cacheImage, err = image.DownloadAndVerify(p.Image, cacheImage, class.Checksum) + if err != nil { return fmt.Errorf("download image: %w", err) } - if err := image.VerifyChecksum(cacheImage, class.Checksum); err != nil { - return fmt.Errorf("verify checksum: %w", err) - } if err := writer.WriteImage(ctx, cacheImage, output); err != nil { return fmt.Errorf("copy base image: %w", err) } @@ -41,3 +40,8 @@ func BuildImageFile(ctx context.Context, inv *inventory.Inventory, nodeName, cac } return nil } + +func cacheName(source string) string { + base := filepath.Base(source) + return strings.TrimSuffix(base, ".xz") +} diff --git a/pkg/service/app.go b/pkg/service/app.go index 7c9f5a3..cec0b8d 100644 --- a/pkg/service/app.go +++ b/pkg/service/app.go @@ -377,7 +377,7 @@ func (a *App) runBuild(job *Job, flash bool) { a.setJob(job.ID, func(j *Job) { j.Status = JobRunning j.Stage = "download" - j.Message = "Fetching base image" + j.Message = "Fetching and verifying base image" j.ProgressPct = 5 }) output := a.artifactPath(job.Node) @@ -395,18 +395,9 @@ func (a *App) runBuild(job *Job, flash bool) { a.metrics.RecordBuild(job.Node, "error") return } - cacheImage := filepath.Join(cacheDir, filepath.Base(planData.Image)) - if err := image.Download(planData.Image, cacheImage); err != nil { - a.failJob(job.ID, err) - a.metrics.RecordBuild(job.Node, "error") - return - } - a.setJob(job.ID, func(j *Job) { - j.Stage = "verify" - j.Message = "Verifying base image checksum" - j.ProgressPct = 18 - }) - if err := image.VerifyChecksum(cacheImage, class.Checksum); err != nil { + cacheImage := filepath.Join(cacheDir, cachedImageName(planData.Image)) + cacheImage, err = image.DownloadAndVerify(planData.Image, cacheImage, class.Checksum) + if err != nil { a.failJob(job.ID, err) a.metrics.RecordBuild(job.Node, "error") return @@ -414,7 +405,7 @@ func (a *App) runBuild(job *Job, flash bool) { a.setJob(job.ID, func(j *Job) { j.Stage = "copy" j.Message = "Copying base image into artifact" - j.ProgressPct = 35 + j.ProgressPct = 24 }) if err := writer.WriteImage(context.Background(), cacheImage, output); err != nil { a.failJob(job.ID, err) @@ -649,6 +640,10 @@ func (a *App) artifactPath(node string) string { return filepath.Join(a.settings.ArtifactDir, fmt.Sprintf("%s.img", node)) } +func cachedImageName(source string) string { + return strings.TrimSuffix(filepath.Base(source), ".xz") +} + func (a *App) flashHosts() []string { hosts := map[string]struct{}{} for _, host := range a.settings.FlashHosts { diff --git a/pkg/service/server.go b/pkg/service/server.go index 9f794c0..4995d74 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -518,6 +518,7 @@ var metisPage = template.Must(template.New("metis").Parse(`
+