diff --git a/pkg/image/download.go b/pkg/image/download.go index 6c3fee6..a397267 100644 --- a/pkg/image/download.go +++ b/pkg/image/download.go @@ -2,6 +2,7 @@ package image import ( "archive/zip" + "crypto/md5" "crypto/sha256" "encoding/hex" "errors" @@ -159,26 +160,39 @@ func decompressZIP(src, dest string) error { return out.Sync() } -// VerifyChecksum checks sha256 in the form "sha256:". +// VerifyChecksum checks hashes in the form "sha256:" or "md5:". func VerifyChecksum(path, checksum string) error { if checksum == "" { return nil } parts := strings.SplitN(checksum, ":", 2) - if len(parts) != 2 || parts[0] != "sha256" { - return errors.New("unsupported checksum format; use sha256:") + if len(parts) != 2 { + return errors.New("unsupported checksum format; use sha256: or md5:") } + algo := strings.ToLower(strings.TrimSpace(parts[0])) expected := strings.ToLower(parts[1]) f, err := os.Open(path) if err != nil { return err } defer f.Close() - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return err + var sum string + switch algo { + case "sha256": + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return err + } + sum = hex.EncodeToString(h.Sum(nil)) + case "md5": + h := md5.New() + if _, err := io.Copy(h, f); err != nil { + return err + } + sum = hex.EncodeToString(h.Sum(nil)) + default: + return errors.New("unsupported checksum format; use sha256: or md5:") } - sum := hex.EncodeToString(h.Sum(nil)) if sum != expected { return fmt.Errorf("checksum mismatch: expected %s got %s", expected, sum) } diff --git a/pkg/image/download_test.go b/pkg/image/download_test.go index be07235..d0357ae 100644 --- a/pkg/image/download_test.go +++ b/pkg/image/download_test.go @@ -2,6 +2,7 @@ package image import ( "archive/zip" + "crypto/md5" "crypto/sha256" "encoding/hex" "os" @@ -160,6 +161,18 @@ func TestDownloadAndVerifyUsesArchiveChecksumForZIP(t *testing.T) { } } +func TestVerifyChecksumAcceptsMD5(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "sample.img") + if err := os.WriteFile(path, []byte("metis-md5-test"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + sum := md5.Sum([]byte("metis-md5-test")) + if err := VerifyChecksum(path, "md5:"+hex.EncodeToString(sum[:])); err != nil { + t.Fatalf("VerifyChecksum md5: %v", err) + } +} + func writeTestZIP(path string, files map[string]string) error { out, err := os.Create(path) if err != nil {