metis/pkg/image/rootfs.go

282 lines
7.9 KiB
Go
Raw Normal View History

package image
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"metis/pkg/inject"
)
// Types below mirror the subset of the `sfdisk -J` JSON document that this
// package reads; unknown keys in the document are ignored by encoding/json.
type (
	// partitionTable is the top-level JSON object, which wraps the table
	// under the "partitiontable" key.
	partitionTable struct {
		PartitionTable partitionTableData `json:"partitiontable"`
	}

	// partitionTableData holds the sector size and the partition entries.
	// SectorSize is zero when the key is absent from the JSON.
	partitionTableData struct {
		SectorSize uint64               `json:"sectorsize"`
		Partitions []partitionTablePart `json:"partitions"`
	}

	// partitionTablePart is one partition entry. Start and Size are in
	// sectors (callers multiply by the sector size to get bytes); Type is the
	// partition type as printed by sfdisk (an MBR hex code or a GPT GUID).
	partitionTablePart struct {
		Start uint64 `json:"start"`
		Size  uint64 `json:"size"`
		Type  string `json:"type"`
	}
)
// RootFSProgressFunc is called by InjectRootFSWithProgress with the name of
// each coarse step as that step begins. The step names passed are the
// RootFSProgress* constants declared below.
type RootFSProgressFunc func(step string)
// Step names reported to a RootFSProgressFunc, listed in the order
// InjectRootFSWithProgress emits them.
const (
// RootFSProgressFindingPartition: reading the image's partition table to
// locate the Linux root partition.
RootFSProgressFindingPartition = "finding-partition"
// RootFSProgressExtracting: copying the root partition out of the image into
// a scratch file.
RootFSProgressExtracting = "extracting-partition"
// RootFSProgressWritingFiles: writing the rootfs files into the scratch copy
// of the filesystem.
RootFSProgressWritingFiles = "writing-rootfs-files"
// RootFSProgressReplacing: copying the modified filesystem back over the
// partition inside the image.
RootFSProgressReplacing = "replacing-partition"
)
// InjectRootFS rewrites the Linux root partition inside the raw image file at
// imagePath without mounting any block device. Only entries of files whose
// RootFS flag is set are written; the rest are ignored.
//
// It behaves exactly like InjectRootFSWithProgress with no progress callback.
func InjectRootFS(imagePath string, files []inject.FileSpec) error {
	var silent RootFSProgressFunc
	return InjectRootFSWithProgress(imagePath, files, silent)
}
// InjectRootFSWithProgress rewrites the Linux root partition inside the raw
// image at imagePath and reports coarse progress through progress, which may
// be nil.
//
// Only entries of files with RootFS set are written. If there are none, it
// returns nil immediately without inspecting the image. Otherwise it:
//  1. locates the Linux partition from the image's partition table;
//  2. copies that partition into a scratch file under a fresh temporary
//     directory (removed on return);
//  3. writes the selected files into the scratch filesystem; and
//  4. copies the scratch filesystem back over the partition in place.
//
// Each step is announced with the matching RootFSProgress* constant before
// it starts. The first failing step's error is returned unchanged.
func InjectRootFSWithProgress(imagePath string, files []inject.FileSpec, progress RootFSProgressFunc) error {
	selected := make([]inject.FileSpec, 0, len(files))
	for i := range files {
		if files[i].RootFS {
			selected = append(selected, files[i])
		}
	}
	if len(selected) == 0 {
		return nil
	}

	step := func(name string) { emitRootFSProgress(progress, name) }

	step(RootFSProgressFindingPartition)
	part, sectorSize, err := findLinuxPartition(imagePath)
	if err != nil {
		return err
	}

	scratch, err := os.MkdirTemp("", "metis-rootfs-")
	if err != nil {
		return err
	}
	defer os.RemoveAll(scratch)
	rootImage := filepath.Join(scratch, "root.ext4")

	step(RootFSProgressExtracting)
	if err = extractPartition(imagePath, rootImage, part, sectorSize); err != nil {
		return err
	}

	step(RootFSProgressWritingFiles)
	if err = writeExt4Files(rootImage, selected); err != nil {
		return err
	}

	step(RootFSProgressReplacing)
	return replacePartition(imagePath, rootImage, part, sectorSize)
}
// emitRootFSProgress forwards step to progress. A nil callback is allowed and
// turns the call into a no-op.
func emitRootFSProgress(progress RootFSProgressFunc, step string) {
	if progress == nil {
		return
	}
	progress(step)
}
// findLinuxPartition locates the Linux partition inside the raw disk image at
// imagePath.
//
// It runs `sfdisk -J` on the image, decodes the JSON partition table, and
// returns the last partition whose type isLinuxPartitionType accepts. It also
// returns the table's sector size, defaulting to 512 when sfdisk reports none.
//
// If sfdisk fails, its stderr is included in the returned error. Otherwise
// the failure would surface only as "exit status N", hiding causes such as an
// image with no recognizable partition table.
func findLinuxPartition(imagePath string) (partitionTablePart, uint64, error) {
	cmd := exec.Command("sfdisk", "-J", imagePath)
	var stderr bytes.Buffer
	cmd.Stderr = &stderr
	out, err := cmd.Output()
	if err != nil {
		if msg := strings.TrimSpace(stderr.String()); msg != "" {
			return partitionTablePart{}, 0, fmt.Errorf("sfdisk -J %s: %w: %s", imagePath, err, msg)
		}
		return partitionTablePart{}, 0, fmt.Errorf("sfdisk -J %s: %w", imagePath, err)
	}
	var table partitionTable
	if err := json.Unmarshal(out, &table); err != nil {
		return partitionTablePart{}, 0, fmt.Errorf("decode partition table: %w", err)
	}
	sectorSize := table.PartitionTable.SectorSize
	if sectorSize == 0 {
		sectorSize = 512
	}
	// Scan from the end so the last Linux-typed partition wins.
	// NOTE(review): presumably because the root filesystem follows any
	// boot/firmware partitions in the images this targets — confirm.
	for i := len(table.PartitionTable.Partitions) - 1; i >= 0; i-- {
		part := table.PartitionTable.Partitions[i]
		if isLinuxPartitionType(part.Type) {
			return part, sectorSize, nil
		}
	}
	return partitionTablePart{}, 0, fmt.Errorf("no Linux root partition found in %s", imagePath)
}
// isLinuxPartitionType reports whether partType, as printed by sfdisk,
// identifies a Linux filesystem partition. The comparison ignores surrounding
// whitespace and letter case.
//
// Accepted values are the MBR code "83", the gdisk-style code "8300", and the
// GPT "Linux filesystem data" GUID 0FC63DAF-8483-4772-8E79-3D69D8477DE4.
func isLinuxPartitionType(partType string) bool {
	switch strings.ToLower(strings.TrimSpace(partType)) {
	case "83", "8300", "0fc63daf-8483-4772-8e79-3d69d8477de4":
		return true
	default:
		return false
	}
}
// extractPartition copies the bytes of part out of the raw image at imagePath
// into outPath, which is created or truncated, and fsyncs the result.
//
// The byte offset and length are part.Start and part.Size multiplied by
// sectorSize. If the image ends before the partition does, io.CopyN's short-read
// error is returned wrapped as "extract root partition: ...".
func extractPartition(imagePath, outPath string, part partitionTablePart, sectorSize uint64) error {
	length := int64(part.Size * sectorSize)
	offset := int64(part.Start * sectorSize)

	image, err := os.Open(imagePath)
	if err != nil {
		return err
	}
	defer image.Close()

	if _, err = image.Seek(offset, io.SeekStart); err != nil {
		return err
	}

	dst, err := os.Create(outPath)
	if err != nil {
		return err
	}
	defer dst.Close()

	if _, err = io.CopyN(dst, image, length); err != nil {
		return fmt.Errorf("extract root partition: %w", err)
	}
	return dst.Sync()
}
// replacePartition writes the filesystem image at rootImage back over part
// inside the raw image at imagePath, in place, and fsyncs the image.
//
// rootImage must be exactly part.Size*sectorSize bytes. Any other size is
// rejected before imagePath is opened, so the copy, which starts at the
// partition's first byte, cannot run past the end of the partition. The image
// is opened write-only without truncation, so bytes outside the partition are
// not touched.
func replacePartition(imagePath, rootImage string, part partitionTablePart, sectorSize uint64) error {
	want := int64(part.Size * sectorSize)
	st, err := os.Stat(rootImage)
	if err != nil {
		return err
	}
	if have := st.Size(); have != want {
		return fmt.Errorf("root partition size mismatch: expected %d got %d", want, have)
	}

	src, err := os.Open(rootImage)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := os.OpenFile(imagePath, os.O_WRONLY, 0)
	if err != nil {
		return err
	}
	defer dst.Close()

	offset := int64(part.Start * sectorSize)
	if _, err = dst.Seek(offset, io.SeekStart); err != nil {
		return err
	}
	if _, err = io.Copy(dst, src); err != nil {
		return fmt.Errorf("write root partition: %w", err)
	}
	return dst.Sync()
}
// writeExt4Files writes files into the ext4 filesystem image at fsPath using
// `debugfs -w`, then reads every file back to confirm its mode and content.
//
// Each file's content is first staged on the host under a scratch directory,
// which is removed on return. A debugfs command script then:
//   - issues mkdir for every parent directory of every destination, shallowest
//     first so a parent always precedes its children;
//   - issues rm for each destination (NOTE(review): presumably because debugfs
//     write will not overwrite an existing entry — confirm);
//   - writes the staged copy to the destination; and
//   - sets the destination inode's mode to a regular file (0100000) combined
//     with the FileSpec's permission bits.
//
// Every path argument in the script is double-quoted, so paths containing
// spaces reach debugfs as a single argument instead of being split.
//
// NOTE(review): mkdir of an existing directory and rm of a missing file make
// debugfs print errors. This relies on debugfs carrying on with the rest of
// the script; the read-back verification at the end is what catches a write
// that really failed.
func writeExt4Files(fsPath string, files []inject.FileSpec) error {
	workDir, err := os.MkdirTemp("", "metis-ext4-")
	if err != nil {
		return err
	}
	defer os.RemoveAll(workDir)
	stageDir := filepath.Join(workDir, "stage")
	commandFile := filepath.Join(workDir, "commands.txt")

	// debugfs tokenizes script lines with libss, which splits on whitespace.
	// A double-quoted argument is kept whole, and "" inside the quotes stands
	// for one literal quote character.
	quote := func(arg string) string {
		return `"` + strings.ReplaceAll(arg, `"`, `""`) + `"`
	}

	dirs := map[string]struct{}{}
	for _, f := range files {
		localPath := filepath.Join(stageDir, filepath.FromSlash(f.Path))
		if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil {
			return err
		}
		if err := os.WriteFile(localPath, f.Content, 0o644); err != nil {
			return err
		}
		for _, dir := range parentDirs(f.Path) {
			dirs[dir] = struct{}{}
		}
	}

	dirList := make([]string, 0, len(dirs))
	for dir := range dirs {
		dirList = append(dirList, dir)
	}
	// Shallowest first so parents are created before children; ties are
	// broken lexically so the script is deterministic.
	sort.Slice(dirList, func(i, j int) bool {
		leftDepth := strings.Count(dirList[i], "/")
		rightDepth := strings.Count(dirList[j], "/")
		if leftDepth != rightDepth {
			return leftDepth < rightDepth
		}
		return dirList[i] < dirList[j]
	})

	commands := make([]string, 0, len(dirList)+len(files)*3)
	for _, dir := range dirList {
		commands = append(commands, "mkdir "+quote(dir))
	}
	for _, f := range files {
		destPath := "/" + strings.TrimPrefix(filepath.ToSlash(f.Path), "/")
		localPath := filepath.Join(stageDir, filepath.FromSlash(f.Path))
		commands = append(commands,
			"rm "+quote(destPath),
			"write "+quote(localPath)+" "+quote(destPath),
			// 0100000 is the regular-file type bit, so the written inode is
			// a regular file carrying the requested permission bits.
			fmt.Sprintf("sif %s mode 0%o", quote(destPath), uint32(0o100000|f.Mode.Perm())),
		)
	}
	if err := os.WriteFile(commandFile, []byte(strings.Join(commands, "\n")+"\n"), 0o644); err != nil {
		return err
	}

	cmd := exec.Command("debugfs", "-w", "-f", commandFile, fsPath)
	var combined bytes.Buffer
	cmd.Stdout = &combined
	cmd.Stderr = &combined
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("debugfs write failed: %w: %s", err, combined.String())
	}

	for _, f := range files {
		if err := verifyExt4File(fsPath, f, workDir); err != nil {
			return err
		}
	}
	return nil
}
// verifyExt4File confirms that file exists in the ext4 image at fsPath with
// the expected permission bits and content.
//
// It runs `debugfs -R stat` on the destination path. The octal value after
// "Mode:" in the output is parsed and compared numerically with
// file.Mode.Perm(). debugfs's stat output may pad that field with extra
// spaces (e.g. "Mode:  0644"), so an exact "Mode: 0644" substring match is not
// relied on. It then runs `debugfs -R dump` to copy the file into workDir
// (under the file's own relative path plus ".readback") and compares the
// bytes with file.Content.
//
// Path arguments are double-quoted so paths containing spaces are passed to
// debugfs as a single argument.
func verifyExt4File(fsPath string, file inject.FileSpec, workDir string) error {
	// debugfs parses -R requests with libss, which splits on whitespace. A
	// double-quoted argument is kept whole, and "" inside the quotes stands
	// for one literal quote character.
	quote := func(arg string) string {
		return `"` + strings.ReplaceAll(arg, `"`, `""`) + `"`
	}
	destPath := "/" + strings.TrimPrefix(filepath.ToSlash(file.Path), "/")

	statOut, err := exec.Command("debugfs", "-R", "stat "+quote(destPath), fsPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("verify %s: %w: %s", destPath, err, string(statOut))
	}
	expectedMode := fmt.Sprintf("Mode: %04o", file.Mode.Perm())
	statText := string(statOut)
	modeOK := false
	if i := strings.Index(statText, "Mode:"); i >= 0 {
		if fields := strings.Fields(statText[i+len("Mode:"):]); len(fields) > 0 {
			var gotPerm uint32
			_, scanErr := fmt.Sscanf(fields[0], "%o", &gotPerm)
			modeOK = scanErr == nil && gotPerm == uint32(file.Mode.Perm())
		}
	}
	if !modeOK {
		return fmt.Errorf("verify %s mode: expected %s in %s", destPath, expectedMode, statText)
	}

	readback := filepath.Join(workDir, strings.TrimPrefix(filepath.FromSlash(file.Path), string(filepath.Separator))+".readback")
	if err := os.MkdirAll(filepath.Dir(readback), 0o755); err != nil {
		return err
	}
	dumpOut, err := exec.Command("debugfs", "-R", "dump "+quote(destPath)+" "+quote(readback), fsPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("dump %s: %w: %s", destPath, err, string(dumpOut))
	}
	got, err := os.ReadFile(readback)
	if err != nil {
		return err
	}
	if !bytes.Equal(got, file.Content) {
		return fmt.Errorf("verify %s content mismatch", destPath)
	}
	return nil
}
// parentDirs returns every ancestor directory of path as an absolute,
// slash-separated path, from the outermost inward. The root "/" itself is not
// included.
//
// path may be given with or without a leading slash; e.g. "etc/ssh/sshd_config"
// yields ["/etc", "/etc/ssh"]. A path with no directory component yields nil.
func parentDirs(path string) []string {
	segments := strings.Split(strings.TrimPrefix(filepath.ToSlash(path), "/"), "/")
	var ancestors []string
	prefix := ""
	for _, seg := range segments[:len(segments)-1] {
		prefix += "/" + seg
		ancestors = append(ancestors, prefix)
	}
	return ancestors
}