84 lines
1.8 KiB
Go
84 lines
1.8 KiB
Go
|
|
package writer
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"strings"
|
||
|
|
)
|
||
|
|
|
||
|
|
// ProgressFunc is the callback used to observe an image write as it proceeds.
// written is the running count of bytes copied so far; total is the size the
// caller expects to copy, passed through unchanged on every call.
type ProgressFunc func(written, total int64)
|
||
|
|
|
||
|
|
// WriteImage writes src into dest using a direct buffered copy so callers can
|
||
|
|
// share the same codepath for files and block devices.
|
||
|
|
func WriteImage(ctx context.Context, src, dest string) error {
|
||
|
|
return WriteImageWithProgress(ctx, src, dest, nil)
|
||
|
|
}
|
||
|
|
|
||
|
|
// WriteImageWithProgress writes src into dest while invoking progress after each chunk.
|
||
|
|
func WriteImageWithProgress(ctx context.Context, src, dest string, progress ProgressFunc) error {
|
||
|
|
if dest == "" {
|
||
|
|
return fmt.Errorf("destination required")
|
||
|
|
}
|
||
|
|
srcInfo, err := os.Stat(src)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("source missing: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return copyFile(ctx, src, dest, srcInfo.Size(), progress)
|
||
|
|
}
|
||
|
|
|
||
|
|
// isDevicePath reports whether path, after lexical cleaning with
// filepath.Clean, sits below the /dev/ directory. The check is purely textual:
// it does not touch the filesystem, so it neither resolves symlinks nor
// confirms that a device node exists. "/dev" itself is not a device path.
func isDevicePath(path string) bool {
	const devDir = "/dev/"
	cleaned := filepath.Clean(path)
	// TrimPrefix only changes the string when the prefix is present.
	return strings.TrimPrefix(cleaned, devDir) != cleaned
}
|
||
|
|
|
||
|
|
func copyFile(ctx context.Context, src, dest string, total int64, progress ProgressFunc) error {
|
||
|
|
if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
in, err := os.Open(src)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer in.Close()
|
||
|
|
out, err := os.Create(dest)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer out.Close()
|
||
|
|
buf := make([]byte, 4*1024*1024)
|
||
|
|
var written int64
|
||
|
|
for {
|
||
|
|
if err := ctx.Err(); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
nr, readErr := in.Read(buf)
|
||
|
|
if nr > 0 {
|
||
|
|
nw, writeErr := out.Write(buf[:nr])
|
||
|
|
if writeErr != nil {
|
||
|
|
return writeErr
|
||
|
|
}
|
||
|
|
if nw != nr {
|
||
|
|
return io.ErrShortWrite
|
||
|
|
}
|
||
|
|
written += int64(nw)
|
||
|
|
if progress != nil {
|
||
|
|
progress(written, total)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if readErr != nil {
|
||
|
|
if readErr == io.EOF {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
return readErr
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if err := out.Sync(); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|