package image import ( "bytes" "encoding/json" "fmt" "io" "os" "os/exec" "path/filepath" "sort" "strings" "metis/pkg/inject" ) type partitionTable struct { PartitionTable partitionTableData `json:"partitiontable"` } type partitionTableData struct { SectorSize uint64 `json:"sectorsize"` Partitions []partitionTablePart `json:"partitions"` } type partitionTablePart struct { Start uint64 `json:"start"` Size uint64 `json:"size"` Type string `json:"type"` } // RootFSProgressUpdate carries coarse-grained step changes and optional byte // counters while Metis rewrites a Linux root filesystem inside a raw image. type RootFSProgressUpdate struct { Step string WrittenBytes int64 TotalBytes int64 } // RootFSProgressFunc receives RootFS progress updates during image rewriting. type RootFSProgressFunc func(update RootFSProgressUpdate) const ( RootFSProgressFindingPartition = "finding-partition" RootFSProgressExtracting = "extracting-partition" RootFSProgressWritingFiles = "writing-rootfs-files" RootFSProgressReplacing = "replacing-partition" ) // InjectRootFS rewrites the Linux root partition inside a raw image file without // requiring block-device mounts. Only rootfs-targeted files are written. func InjectRootFS(imagePath string, files []inject.FileSpec) error { return InjectRootFSWithProgress(imagePath, files, nil) } // InjectRootFSWithProgress emits coarse step changes while rewriting the root partition. func InjectRootFSWithProgress(imagePath string, files []inject.FileSpec, progress RootFSProgressFunc) error { rootFiles := make([]inject.FileSpec, 0, len(files)) for _, f := range files { if f.RootFS { rootFiles = append(rootFiles, f) } } if len(rootFiles) == 0 { return nil } emitRootFSProgress(progress, RootFSProgressUpdate{Step: RootFSProgressFindingPartition}) part, sectorSize, err := findLinuxPartition(imagePath) if err != nil { return err } workDir, err := mkdirTempNearPath(imagePath, "metis-rootfs-") if err != nil { return err } defer os.RemoveAll(workDir) rootImage := filepath.Join(workDir, "root.ext4") emitRootFSProgress(progress, RootFSProgressUpdate{Step: RootFSProgressExtracting}) if err := extractPartition(imagePath, rootImage, part, sectorSize); err != nil { return err } emitRootFSProgress(progress, RootFSProgressUpdate{Step: RootFSProgressWritingFiles}) if err := writeExt4Files(rootImage, rootFiles); err != nil { return err } emitRootFSProgress(progress, RootFSProgressUpdate{Step: RootFSProgressReplacing}) return replacePartition(imagePath, rootImage, part, sectorSize, func(written, total int64) { emitRootFSProgress(progress, RootFSProgressUpdate{ Step: RootFSProgressReplacing, WrittenBytes: written, TotalBytes: total, }) }) } func emitRootFSProgress(progress RootFSProgressFunc, update RootFSProgressUpdate) { if progress != nil { progress(update) } } func findLinuxPartition(imagePath string) (partitionTablePart, uint64, error) { out, err := exec.Command("sfdisk", "-J", imagePath).Output() if err != nil { return partitionTablePart{}, 0, fmt.Errorf("sfdisk -J %s: %w", imagePath, err) } var table partitionTable if err := json.Unmarshal(out, &table); err != nil { return partitionTablePart{}, 0, fmt.Errorf("decode partition table: %w", err) } sectorSize := table.PartitionTable.SectorSize if sectorSize == 0 { sectorSize = 512 } for i := len(table.PartitionTable.Partitions) - 1; i >= 0; i-- { part := table.PartitionTable.Partitions[i] if isLinuxPartitionType(part.Type) { return part, sectorSize, nil } } return partitionTablePart{}, 0, fmt.Errorf("no Linux root partition found in %s", imagePath) } func isLinuxPartitionType(partType string) bool { normalized := strings.ToLower(strings.TrimSpace(partType)) switch normalized { case "83", "8300": return true } return normalized == "0fc63daf-8483-4772-8e79-3d69d8477de4" } func extractPartition(imagePath, outPath string, part partitionTablePart, sectorSize uint64) error { sizeBytes := int64(part.Size * sectorSize) offsetBytes := int64(part.Start * sectorSize) src, err := os.Open(imagePath) if err != nil { return err } defer src.Close() if _, err := src.Seek(offsetBytes, io.SeekStart); err != nil { return err } out, err := os.Create(outPath) if err != nil { return err } defer out.Close() if _, err := io.CopyN(out, src, sizeBytes); err != nil { return fmt.Errorf("extract root partition: %w", err) } return out.Sync() } func replacePartition(imagePath, rootImage string, part partitionTablePart, sectorSize uint64, progress func(written, total int64)) error { expectedSize := int64(part.Size * sectorSize) info, err := os.Stat(rootImage) if err != nil { return err } if info.Size() != expectedSize { return fmt.Errorf("root partition size mismatch: expected %d got %d", expectedSize, info.Size()) } in, err := os.Open(rootImage) if err != nil { return err } defer in.Close() out, err := os.OpenFile(imagePath, os.O_WRONLY, 0) if err != nil { return err } defer out.Close() if _, err := out.Seek(int64(part.Start*sectorSize), io.SeekStart); err != nil { return err } if _, err := copyWithProgress(out, in, expectedSize, progress); err != nil { return fmt.Errorf("write root partition: %w", err) } return out.Sync() } func writeExt4Files(fsPath string, files []inject.FileSpec) error { workDir, err := mkdirTempNearPath(fsPath, "metis-ext4-") if err != nil { return err } defer os.RemoveAll(workDir) stageDir := filepath.Join(workDir, "stage") commandFile := filepath.Join(workDir, "commands.txt") dirs := map[string]struct{}{} commands := make([]string, 0, len(files)*4) for _, f := range files { localPath := filepath.Join(stageDir, filepath.FromSlash(f.Path)) if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil { return err } if err := os.WriteFile(localPath, f.Content, 0o644); err != nil { return err } for _, dir := range parentDirs(f.Path) { dirs[dir] = struct{}{} } } dirList := make([]string, 0, len(dirs)) for dir := range dirs { dirList = append(dirList, dir) } sort.Slice(dirList, func(i, j int) bool { leftDepth := strings.Count(dirList[i], "/") rightDepth := strings.Count(dirList[j], "/") if leftDepth != rightDepth { return leftDepth < rightDepth } return dirList[i] < dirList[j] }) for _, dir := range dirList { commands = append(commands, fmt.Sprintf("mkdir %s", dir)) } for _, f := range files { destPath := "/" + strings.TrimPrefix(filepath.ToSlash(f.Path), "/") localPath := filepath.Join(stageDir, filepath.FromSlash(f.Path)) commands = append(commands, fmt.Sprintf("rm %s", destPath)) commands = append(commands, fmt.Sprintf("write %s %s", localPath, destPath)) commands = append(commands, fmt.Sprintf("sif %s mode 0%o", destPath, uint32(0o100000|f.Mode.Perm()))) } if err := os.WriteFile(commandFile, []byte(strings.Join(commands, "\n")+"\n"), 0o644); err != nil { return err } cmd := exec.Command("debugfs", "-w", "-f", commandFile, fsPath) var combined bytes.Buffer cmd.Stdout = &combined cmd.Stderr = &combined if err := cmd.Run(); err != nil { return fmt.Errorf("debugfs write failed: %w: %s", err, combined.String()) } for _, f := range files { if err := verifyExt4File(fsPath, f, workDir); err != nil { return err } } return nil } func mkdirTempNearPath(targetPath, pattern string) (string, error) { parent := strings.TrimSpace(os.Getenv("METIS_ROOTFS_TMP_DIR")) if parent == "" && strings.TrimSpace(targetPath) != "" { parent = filepath.Join(filepath.Dir(targetPath), ".metis-tmp") } if parent == "" { return os.MkdirTemp("", pattern) } if err := os.MkdirAll(parent, 0o755); err != nil { return "", err } return os.MkdirTemp(parent, pattern) } func copyWithProgress(dst io.Writer, src io.Reader, total int64, progress func(written, total int64)) (int64, error) { buf := make([]byte, 2*1024*1024) var written int64 for { nr, er := src.Read(buf) if nr > 0 { nw, ew := dst.Write(buf[:nr]) written += int64(nw) if progress != nil { progress(written, total) } if ew != nil { return written, ew } if nw != nr { return written, io.ErrShortWrite } } if er != nil { if er == io.EOF { return written, nil } return written, er } } } func verifyExt4File(fsPath string, file inject.FileSpec, workDir string) error { destPath := "/" + strings.TrimPrefix(filepath.ToSlash(file.Path), "/") statOut, err := exec.Command("debugfs", "-R", "stat "+destPath, fsPath).CombinedOutput() if err != nil { return fmt.Errorf("verify %s: %w: %s", destPath, err, string(statOut)) } expectedMode := fmt.Sprintf("Mode: %04o", file.Mode.Perm()) if !strings.Contains(string(statOut), expectedMode) { return fmt.Errorf("verify %s mode: expected %s in %s", destPath, expectedMode, string(statOut)) } readback := filepath.Join(workDir, strings.TrimPrefix(filepath.FromSlash(file.Path), string(filepath.Separator))+".readback") if err := os.MkdirAll(filepath.Dir(readback), 0o755); err != nil { return err } dumpOut, err := exec.Command("debugfs", "-R", fmt.Sprintf("dump %s %s", destPath, readback), fsPath).CombinedOutput() if err != nil { return fmt.Errorf("dump %s: %w: %s", destPath, err, string(dumpOut)) } got, err := os.ReadFile(readback) if err != nil { return err } if !bytes.Equal(got, file.Content) { return fmt.Errorf("verify %s content mismatch", destPath) } return nil } func parentDirs(path string) []string { cleaned := "/" + strings.TrimPrefix(filepath.ToSlash(path), "/") parts := strings.Split(cleaned, "/") var dirs []string for i := 2; i < len(parts); i++ { dirs = append(dirs, strings.Join(parts[:i], "/")) } return dirs }