258 lines
7.1 KiB
Go
258 lines
7.1 KiB
Go
|
|
package image
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"os"
|
||
|
|
"os/exec"
|
||
|
|
"path/filepath"
|
||
|
|
"sort"
|
||
|
|
"strings"
|
||
|
|
|
||
|
|
"metis/pkg/inject"
|
||
|
|
)
|
||
|
|
|
||
|
|
// The types below mirror the subset of `sfdisk -J` (JSON dump) output that
// this package reads. Keys present in the JSON but not declared here are
// ignored by encoding/json.
type (
	// partitionTable is the top-level JSON object: {"partitiontable": {...}}.
	partitionTable struct {
		PartitionTable partitionTableData `json:"partitiontable"`
	}

	// partitionTableData carries the table-wide sector size and the entries.
	partitionTableData struct {
		// SectorSize is the sector size in bytes; it stays zero when the JSON
		// has no "sectorsize" key.
		SectorSize uint64               `json:"sectorsize"`
		Partitions []partitionTablePart `json:"partitions"`
	}

	// partitionTablePart is a single partition entry.
	partitionTablePart struct {
		Start uint64 `json:"start"` // first sector of the partition
		Size  uint64 `json:"size"`  // length of the partition in sectors
		Type  string `json:"type"`  // type code or GUID exactly as sfdisk prints it
	}
)
|
||
|
|
|
||
|
|
// InjectRootFS rewrites the Linux root partition inside a raw image file without
|
||
|
|
// requiring block-device mounts. Only rootfs-targeted files are written.
|
||
|
|
func InjectRootFS(imagePath string, files []inject.FileSpec) error {
|
||
|
|
rootFiles := make([]inject.FileSpec, 0, len(files))
|
||
|
|
for _, f := range files {
|
||
|
|
if f.RootFS {
|
||
|
|
rootFiles = append(rootFiles, f)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if len(rootFiles) == 0 {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
part, sectorSize, err := findLinuxPartition(imagePath)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
workDir, err := os.MkdirTemp("", "metis-rootfs-")
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer os.RemoveAll(workDir)
|
||
|
|
|
||
|
|
rootImage := filepath.Join(workDir, "root.ext4")
|
||
|
|
if err := extractPartition(imagePath, rootImage, part, sectorSize); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if err := writeExt4Files(rootImage, rootFiles); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return replacePartition(imagePath, rootImage, part, sectorSize)
|
||
|
|
}
|
||
|
|
|
||
|
|
func findLinuxPartition(imagePath string) (partitionTablePart, uint64, error) {
|
||
|
|
out, err := exec.Command("sfdisk", "-J", imagePath).Output()
|
||
|
|
if err != nil {
|
||
|
|
return partitionTablePart{}, 0, fmt.Errorf("sfdisk -J %s: %w", imagePath, err)
|
||
|
|
}
|
||
|
|
var table partitionTable
|
||
|
|
if err := json.Unmarshal(out, &table); err != nil {
|
||
|
|
return partitionTablePart{}, 0, fmt.Errorf("decode partition table: %w", err)
|
||
|
|
}
|
||
|
|
sectorSize := table.PartitionTable.SectorSize
|
||
|
|
if sectorSize == 0 {
|
||
|
|
sectorSize = 512
|
||
|
|
}
|
||
|
|
for i := len(table.PartitionTable.Partitions) - 1; i >= 0; i-- {
|
||
|
|
part := table.PartitionTable.Partitions[i]
|
||
|
|
if isLinuxPartitionType(part.Type) {
|
||
|
|
return part, sectorSize, nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return partitionTablePart{}, 0, fmt.Errorf("no Linux root partition found in %s", imagePath)
|
||
|
|
}
|
||
|
|
|
||
|
|
// isLinuxPartitionType reports whether partType, a partition type as printed by
// sfdisk, names a generic Linux data partition. It accepts the MBR code "83",
// the gdisk-style "8300", and the GPT "Linux filesystem" GUID. Surrounding
// whitespace is ignored, and the GUID is matched case-insensitively.
func isLinuxPartitionType(partType string) bool {
	const linuxFilesystemGUID = "0fc63daf-8483-4772-8e79-3d69d8477de4"

	t := strings.TrimSpace(partType)
	return t == "83" || t == "8300" || strings.EqualFold(t, linuxFilesystemGUID)
}
|
||
|
|
|
||
|
|
func extractPartition(imagePath, outPath string, part partitionTablePart, sectorSize uint64) error {
|
||
|
|
sizeBytes := int64(part.Size * sectorSize)
|
||
|
|
offsetBytes := int64(part.Start * sectorSize)
|
||
|
|
|
||
|
|
src, err := os.Open(imagePath)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer src.Close()
|
||
|
|
if _, err := src.Seek(offsetBytes, io.SeekStart); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
out, err := os.Create(outPath)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer out.Close()
|
||
|
|
if _, err := io.CopyN(out, src, sizeBytes); err != nil {
|
||
|
|
return fmt.Errorf("extract root partition: %w", err)
|
||
|
|
}
|
||
|
|
return out.Sync()
|
||
|
|
}
|
||
|
|
|
||
|
|
func replacePartition(imagePath, rootImage string, part partitionTablePart, sectorSize uint64) error {
|
||
|
|
expectedSize := int64(part.Size * sectorSize)
|
||
|
|
info, err := os.Stat(rootImage)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if info.Size() != expectedSize {
|
||
|
|
return fmt.Errorf("root partition size mismatch: expected %d got %d", expectedSize, info.Size())
|
||
|
|
}
|
||
|
|
|
||
|
|
in, err := os.Open(rootImage)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer in.Close()
|
||
|
|
|
||
|
|
out, err := os.OpenFile(imagePath, os.O_WRONLY, 0)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer out.Close()
|
||
|
|
if _, err := out.Seek(int64(part.Start*sectorSize), io.SeekStart); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if _, err := io.Copy(out, in); err != nil {
|
||
|
|
return fmt.Errorf("write root partition: %w", err)
|
||
|
|
}
|
||
|
|
return out.Sync()
|
||
|
|
}
|
||
|
|
|
||
|
|
func writeExt4Files(fsPath string, files []inject.FileSpec) error {
|
||
|
|
workDir, err := os.MkdirTemp("", "metis-ext4-")
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer os.RemoveAll(workDir)
|
||
|
|
|
||
|
|
stageDir := filepath.Join(workDir, "stage")
|
||
|
|
commandFile := filepath.Join(workDir, "commands.txt")
|
||
|
|
|
||
|
|
dirs := map[string]struct{}{}
|
||
|
|
commands := make([]string, 0, len(files)*4)
|
||
|
|
|
||
|
|
for _, f := range files {
|
||
|
|
localPath := filepath.Join(stageDir, filepath.FromSlash(f.Path))
|
||
|
|
if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if err := os.WriteFile(localPath, f.Content, 0o644); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
for _, dir := range parentDirs(f.Path) {
|
||
|
|
dirs[dir] = struct{}{}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
dirList := make([]string, 0, len(dirs))
|
||
|
|
for dir := range dirs {
|
||
|
|
dirList = append(dirList, dir)
|
||
|
|
}
|
||
|
|
sort.Slice(dirList, func(i, j int) bool {
|
||
|
|
leftDepth := strings.Count(dirList[i], "/")
|
||
|
|
rightDepth := strings.Count(dirList[j], "/")
|
||
|
|
if leftDepth != rightDepth {
|
||
|
|
return leftDepth < rightDepth
|
||
|
|
}
|
||
|
|
return dirList[i] < dirList[j]
|
||
|
|
})
|
||
|
|
for _, dir := range dirList {
|
||
|
|
commands = append(commands, fmt.Sprintf("mkdir %s", dir))
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, f := range files {
|
||
|
|
destPath := "/" + strings.TrimPrefix(filepath.ToSlash(f.Path), "/")
|
||
|
|
localPath := filepath.Join(stageDir, filepath.FromSlash(f.Path))
|
||
|
|
commands = append(commands, fmt.Sprintf("rm %s", destPath))
|
||
|
|
commands = append(commands, fmt.Sprintf("write %s %s", localPath, destPath))
|
||
|
|
commands = append(commands, fmt.Sprintf("sif %s mode 0%o", destPath, uint32(0o100000|f.Mode.Perm())))
|
||
|
|
}
|
||
|
|
if err := os.WriteFile(commandFile, []byte(strings.Join(commands, "\n")+"\n"), 0o644); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
cmd := exec.Command("debugfs", "-w", "-f", commandFile, fsPath)
|
||
|
|
var combined bytes.Buffer
|
||
|
|
cmd.Stdout = &combined
|
||
|
|
cmd.Stderr = &combined
|
||
|
|
if err := cmd.Run(); err != nil {
|
||
|
|
return fmt.Errorf("debugfs write failed: %w: %s", err, combined.String())
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, f := range files {
|
||
|
|
if err := verifyExt4File(fsPath, f, workDir); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func verifyExt4File(fsPath string, file inject.FileSpec, workDir string) error {
|
||
|
|
destPath := "/" + strings.TrimPrefix(filepath.ToSlash(file.Path), "/")
|
||
|
|
statOut, err := exec.Command("debugfs", "-R", "stat "+destPath, fsPath).CombinedOutput()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("verify %s: %w: %s", destPath, err, string(statOut))
|
||
|
|
}
|
||
|
|
expectedMode := fmt.Sprintf("Mode: %04o", file.Mode.Perm())
|
||
|
|
if !strings.Contains(string(statOut), expectedMode) {
|
||
|
|
return fmt.Errorf("verify %s mode: expected %s in %s", destPath, expectedMode, string(statOut))
|
||
|
|
}
|
||
|
|
|
||
|
|
readback := filepath.Join(workDir, strings.TrimPrefix(filepath.FromSlash(file.Path), string(filepath.Separator))+".readback")
|
||
|
|
if err := os.MkdirAll(filepath.Dir(readback), 0o755); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
dumpOut, err := exec.Command("debugfs", "-R", fmt.Sprintf("dump %s %s", destPath, readback), fsPath).CombinedOutput()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("dump %s: %w: %s", destPath, err, string(dumpOut))
|
||
|
|
}
|
||
|
|
got, err := os.ReadFile(readback)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if !bytes.Equal(got, file.Content) {
|
||
|
|
return fmt.Errorf("verify %s content mismatch", destPath)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// parentDirs returns the ancestor directories of the slash-separated in-image
// path, shallowest first, excluding the root itself. The path is treated as
// absolute whether or not it starts with "/"; for example, "etc/ssh/sshd_config"
// yields ["/etc", "/etc/ssh"]. The path is not cleaned. A path with no parent
// below the root yields nil.
func parentDirs(path string) []string {
	rooted := "/" + strings.TrimPrefix(filepath.ToSlash(path), "/")

	var ancestors []string
	// Every slash after the leading one ends an ancestor prefix.
	for idx := 1; idx < len(rooted); idx++ {
		if rooted[idx] == '/' {
			ancestors = append(ancestors, rooted[:idx])
		}
	}
	return ancestors
}
|