package writer import ( "context" "fmt" "io" "os" "path/filepath" "strings" ) // ProgressFunc receives write progress updates. type ProgressFunc func(written int64, total int64) // WriteImage writes src into dest using a direct buffered copy so callers can // share the same codepath for files and block devices. func WriteImage(ctx context.Context, src, dest string) error { return WriteImageWithProgress(ctx, src, dest, nil) } // WriteImageWithProgress writes src into dest while invoking progress after each chunk. func WriteImageWithProgress(ctx context.Context, src, dest string, progress ProgressFunc) error { if dest == "" { return fmt.Errorf("destination required") } srcInfo, err := os.Stat(src) if err != nil { return fmt.Errorf("source missing: %w", err) } return copyFile(ctx, src, dest, srcInfo.Size(), progress) } func isDevicePath(path string) bool { return strings.HasPrefix(filepath.Clean(path), "/dev/") } func copyFile(ctx context.Context, src, dest string, total int64, progress ProgressFunc) error { if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { return err } in, err := os.Open(src) if err != nil { return err } defer in.Close() out, err := os.Create(dest) if err != nil { return err } defer out.Close() buf := make([]byte, 4*1024*1024) var written int64 for { if err := ctx.Err(); err != nil { return err } nr, readErr := in.Read(buf) if nr > 0 { nw, writeErr := out.Write(buf[:nr]) if writeErr != nil { return writeErr } if nw != nr { return io.ErrShortWrite } written += int64(nw) if progress != nil { progress(written, total) } } if readErr != nil { if readErr == io.EOF { break } return readErr } } if err := out.Sync(); err != nil { return err } return nil }