// soteria/internal/server/server_test.go

package server
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"scm.bstein.dev/bstein/soteria/internal/api"
"scm.bstein.dev/bstein/soteria/internal/config"
"scm.bstein.dev/bstein/soteria/internal/k8s"
"scm.bstein.dev/bstein/soteria/internal/longhorn"
corev1 "k8s.io/api/core/v1"
)
// fakeKubeClient is an in-memory stand-in for the Kubernetes client consumed by
// Server. Every method succeeds; only the bound-PVC listing and the target-PVC
// existence check are configurable, via the struct fields.
type fakeKubeClient struct {
	// pvcs is returned verbatim by ListBoundPVCs.
	pvcs []k8s.PVCSummary
	// targetExists is returned verbatim by PersistentVolumeClaimExists.
	targetExists bool
}

// ResolvePVCVolume derives a deterministic volume name of the form
// "<namespace>-<pvcName>-pv" and returns nil PVC/PV objects and a nil error.
func (c *fakeKubeClient) ResolvePVCVolume(_ context.Context, namespace, pvcName string) (string, *corev1.PersistentVolumeClaim, *corev1.PersistentVolume, error) {
	volumeName := namespace + "-" + pvcName + "-pv"
	return volumeName, nil, nil, nil
}

// CreateBackupJob reports fixed job and secret names without creating anything.
func (c *fakeKubeClient) CreateBackupJob(context.Context, *config.Config, api.BackupRequest) (string, string, error) {
	return "backup-job", "backup-secret", nil
}

// CreateRestoreJob reports fixed job and secret names without creating anything.
func (c *fakeKubeClient) CreateRestoreJob(context.Context, *config.Config, api.RestoreTestRequest) (string, string, error) {
	return "restore-job", "restore-secret", nil
}

// ListBoundPVCs returns the preconfigured pvcs slice (nil when unset).
func (c *fakeKubeClient) ListBoundPVCs(context.Context) ([]k8s.PVCSummary, error) {
	return c.pvcs, nil
}

// PersistentVolumeClaimExists ignores the namespace/name it is asked about and
// answers with the preconfigured targetExists flag.
func (c *fakeKubeClient) PersistentVolumeClaimExists(context.Context, string, string) (bool, error) {
	return c.targetExists, nil
}
// fakeLonghornClient is a canned Longhorn API client for server tests. Every
// call succeeds; only ListBackups is configurable, via the backups field.
type fakeLonghornClient struct {
	// backups is returned verbatim by ListBackups.
	backups []longhorn.Backup
}

// SnapshotBackup pretends to snapshot and back up the volume, echoing back a
// Volume that carries only the requested volume name.
func (l *fakeLonghornClient) SnapshotBackup(_ context.Context, volume, _ string, _ map[string]string, _ string) (*longhorn.Volume, error) {
	return &longhorn.Volume{Name: volume}, nil
}

// GetVolume returns a fixed description of the named volume: size
// "1073741824" (1 GiB if the field is in bytes) and two replicas.
func (l *fakeLonghornClient) GetVolume(_ context.Context, volume string) (*longhorn.Volume, error) {
	described := &longhorn.Volume{Name: volume, Size: "1073741824", NumberOfReplicas: 2}
	return described, nil
}

// CreateVolumeFromBackup echoes the requested name, size and replica count back
// as the "created" volume; the backup URL is ignored.
func (l *fakeLonghornClient) CreateVolumeFromBackup(_ context.Context, name, size string, replicas int, _ string) (*longhorn.Volume, error) {
	created := &longhorn.Volume{Name: name, Size: size, NumberOfReplicas: replicas}
	return created, nil
}

// CreatePVC is a no-op that always succeeds.
func (l *fakeLonghornClient) CreatePVC(context.Context, string, string, string) error {
	return nil
}

// DeleteVolume is a no-op that always succeeds.
func (l *fakeLonghornClient) DeleteVolume(context.Context, string) error {
	return nil
}

// FindBackup ignores its lookup keys and always reports the same completed
// backup, "backup-latest".
func (l *fakeLonghornClient) FindBackup(context.Context, string, string) (*longhorn.Backup, error) {
	found := &longhorn.Backup{Name: "backup-latest", URL: "s3://bucket/backup-latest", State: "Completed"}
	return found, nil
}

// ListBackups returns the preconfigured backups slice for any volume.
func (l *fakeLonghornClient) ListBackups(context.Context, string) ([]longhorn.Backup, error) {
	return l.backups, nil
}
// TestProtectedInventoryRequiresAuth checks that, with AuthRequired enabled, a
// GET /v1/inventory request carrying no identity headers is answered with 401.
func TestProtectedInventoryRequiresAuth(t *testing.T) {
	srv := &Server{
		cfg:      &config.Config{AuthRequired: true, AllowedGroups: []string{"admin", "maintenance"}, BackupDriver: "longhorn"},
		client:   &fakeKubeClient{},
		longhorn: &fakeLonghornClient{},
		metrics:  newTelemetry(),
	}
	srv.handler = http.HandlerFunc(srv.route)

	req := httptest.NewRequest(http.MethodGet, "/v1/inventory", nil)
	res := httptest.NewRecorder()
	srv.Handler().ServeHTTP(res, req)

	if res.Code != http.StatusUnauthorized {
		// Report the body too, so a failure shows what the handler actually sent.
		t.Fatalf("expected 401, got %d: %s", res.Code, res.Body.String())
	}
}
// TestProtectedInventoryAllowsMaintenanceGroup checks that a caller identified
// via the X-Auth-Request-* headers and belonging to the "/maintenance" group can
// read /v1/inventory. It also checks that the response lists exactly the one
// namespace ("apps") backed by the fake bound PVC.
func TestProtectedInventoryAllowsMaintenanceGroup(t *testing.T) {
	srv := &Server{
		cfg:      &config.Config{AuthRequired: true, AllowedGroups: []string{"admin", "maintenance"}, BackupDriver: "longhorn", BackupMaxAge: 24 * time.Hour},
		client:   &fakeKubeClient{pvcs: []k8s.PVCSummary{{Namespace: "apps", Name: "data", VolumeName: "pv-apps-data", Phase: "Bound"}}},
		longhorn: &fakeLonghornClient{backups: []longhorn.Backup{{Name: "backup-1", Created: "2026-04-12T00:00:00Z", State: "Completed", URL: "s3://bucket/backup-1"}}},
		metrics:  newTelemetry(),
	}
	srv.handler = http.HandlerFunc(srv.route)

	req := httptest.NewRequest(http.MethodGet, "/v1/inventory", nil)
	req.Header.Set("X-Auth-Request-User", "brad")
	req.Header.Set("X-Auth-Request-Groups", "/maintenance")
	res := httptest.NewRecorder()
	srv.Handler().ServeHTTP(res, req)

	if res.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", res.Code, res.Body.String())
	}

	// Unmarshal straight from the recorder's bytes. This avoids copying the body
	// into a string first, and Bytes() does not drain the buffer, so the raw body
	// is still available for the failure message.
	var payload api.InventoryResponse
	if err := json.Unmarshal(res.Body.Bytes(), &payload); err != nil {
		t.Fatalf("decode inventory: %v; body: %s", err, res.Body.String())
	}
	if len(payload.Namespaces) != 1 || payload.Namespaces[0].Name != "apps" {
		t.Fatalf("unexpected inventory payload: %#v", payload)
	}
}
// TestProtectedInventoryAllowsForwardedHeaders checks that identity supplied via
// the X-Forwarded-User / X-Forwarded-Groups headers is honoured. The groups value
// here holds two entries separated by ';', one of which ("/maintenance") is
// allowed. It asserts the inventory payload as well as the status, so a handler
// that returns 200 with the wrong body is also caught.
func TestProtectedInventoryAllowsForwardedHeaders(t *testing.T) {
	srv := &Server{
		cfg:      &config.Config{AuthRequired: true, AllowedGroups: []string{"admin", "maintenance"}, BackupDriver: "longhorn", BackupMaxAge: 24 * time.Hour},
		client:   &fakeKubeClient{pvcs: []k8s.PVCSummary{{Namespace: "apps", Name: "data", VolumeName: "pv-apps-data", Phase: "Bound"}}},
		longhorn: &fakeLonghornClient{backups: []longhorn.Backup{{Name: "backup-1", Created: "2026-04-12T00:00:00Z", State: "Completed", URL: "s3://bucket/backup-1"}}},
		metrics:  newTelemetry(),
	}
	srv.handler = http.HandlerFunc(srv.route)

	req := httptest.NewRequest(http.MethodGet, "/v1/inventory", nil)
	req.Header.Set("X-Forwarded-User", "brad")
	req.Header.Set("X-Forwarded-Groups", "/ops;/maintenance")
	res := httptest.NewRecorder()
	srv.Handler().ServeHTTP(res, req)

	if res.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", res.Code, res.Body.String())
	}

	// Same fixtures as the X-Auth-Request-* case, so the same single-namespace
	// inventory is expected once authorization succeeds.
	var payload api.InventoryResponse
	if err := json.Unmarshal(res.Body.Bytes(), &payload); err != nil {
		t.Fatalf("decode inventory: %v; body: %s", err, res.Body.String())
	}
	if len(payload.Namespaces) != 1 || payload.Namespaces[0].Name != "apps" {
		t.Fatalf("unexpected inventory payload: %#v", payload)
	}
}
func TestRestoreRejectsExistingTargetPVC(t *testing.T) {
srv := &Server{
cfg: &config.Config{AuthRequired: false, BackupDriver: "longhorn"},
client: &fakeKubeClient{targetExists: true},
longhorn: &fakeLonghornClient{},
metrics: newTelemetry(),
}
srv.handler = http.HandlerFunc(srv.route)
body := `{"namespace":"apps","pvc":"data","target_namespace":"apps","target_pvc":"restore-data"}`
req := httptest.NewRequest(http.MethodPost, "/v1/restores", strings.NewReader(body))
res := httptest.NewRecorder()
srv.Handler().ServeHTTP(res, req)
if res.Code != http.StatusConflict {
t.Fatalf("expected 409, got %d: %s", res.Code, res.Body.String())
}
}
// TestMetricsStayPublic checks that /metrics is served without identity headers
// even when AuthRequired is enabled. The body must include the
// soteria_backup_requests_total series.
func TestMetricsStayPublic(t *testing.T) {
	srv := &Server{
		cfg:      &config.Config{AuthRequired: true, AllowedGroups: []string{"admin"}},
		client:   &fakeKubeClient{},
		longhorn: &fakeLonghornClient{},
		metrics:  newTelemetry(),
	}
	srv.handler = http.HandlerFunc(srv.route)

	req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
	res := httptest.NewRecorder()
	srv.Handler().ServeHTTP(res, req)

	if res.Code != http.StatusOK {
		// Include the body, like the sibling tests, so an auth rejection or other
		// error page is visible in the failure output.
		t.Fatalf("expected 200, got %d: %s", res.Code, res.Body.String())
	}
	if !strings.Contains(res.Body.String(), "soteria_backup_requests_total") {
		t.Fatalf("expected prometheus metrics body, got %q", res.Body.String())
	}
}