package server

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	corev1 "k8s.io/api/core/v1"

	"scm.bstein.dev/bstein/soteria/internal/api"
	"scm.bstein.dev/bstein/soteria/internal/config"
	"scm.bstein.dev/bstein/soteria/internal/k8s"
	"scm.bstein.dev/bstein/soteria/internal/longhorn"
)

// fakeKubeClient is an in-memory stand-in for the Kubernetes client used by
// Server. ListBoundPVCs returns pvcs verbatim, PersistentVolumeClaimExists
// returns targetExists, and the job-creation methods always succeed with
// fixed names.
type fakeKubeClient struct {
	pvcs         []k8s.PVCSummary
	targetExists bool
}

// ResolvePVCVolume returns the deterministic volume name
// "<namespace>-<pvcName>-pv" and no PVC/PV objects.
func (f *fakeKubeClient) ResolvePVCVolume(_ context.Context, namespace, pvcName string) (string, *corev1.PersistentVolumeClaim, *corev1.PersistentVolume, error) {
	return namespace + "-" + pvcName + "-pv", nil, nil, nil
}

// CreateBackupJob always succeeds with fixed job and secret names.
func (f *fakeKubeClient) CreateBackupJob(_ context.Context, _ *config.Config, _ api.BackupRequest) (string, string, error) {
	return "backup-job", "backup-secret", nil
}

// CreateRestoreJob always succeeds with fixed job and secret names.
func (f *fakeKubeClient) CreateRestoreJob(_ context.Context, _ *config.Config, _ api.RestoreTestRequest) (string, string, error) {
	return "restore-job", "restore-secret", nil
}

// ListBoundPVCs returns the configured pvcs slice.
func (f *fakeKubeClient) ListBoundPVCs(_ context.Context) ([]k8s.PVCSummary, error) {
	return f.pvcs, nil
}

// PersistentVolumeClaimExists reports targetExists for any namespace/name.
func (f *fakeKubeClient) PersistentVolumeClaimExists(_ context.Context, _, _ string) (bool, error) {
	return f.targetExists, nil
}

// fakeLonghornClient is an in-memory stand-in for the Longhorn client used by
// Server. ListBackups returns backups verbatim; every other method succeeds
// with canned values.
type fakeLonghornClient struct {
	backups []longhorn.Backup
}

// SnapshotBackup succeeds and echoes the volume name back.
func (f *fakeLonghornClient) SnapshotBackup(_ context.Context, volume, name string, labels map[string]string, backupMode string) (*longhorn.Volume, error) {
	return &longhorn.Volume{Name: volume}, nil
}

// GetVolume reports the named volume with a fixed size of 1073741824 and two
// replicas.
func (f *fakeLonghornClient) GetVolume(_ context.Context, volume string) (*longhorn.Volume, error) {
	return &longhorn.Volume{Name: volume, Size: "1073741824", NumberOfReplicas: 2}, nil
}

// CreateVolumeFromBackup succeeds and echoes the requested name, size and
// replica count.
func (f *fakeLonghornClient) CreateVolumeFromBackup(_ context.Context, name, size string, replicas int, backupURL string) (*longhorn.Volume, error) {
	return &longhorn.Volume{Name: name, Size: size, NumberOfReplicas: replicas}, nil
}

// CreatePVC always succeeds.
func (f *fakeLonghornClient) CreatePVC(_ context.Context, volumeName, namespace, pvcName string) error {
	return nil
}

// DeleteVolume always succeeds.
func (f *fakeLonghornClient) DeleteVolume(_ context.Context, volumeName string) error {
	return nil
}

// FindBackup returns a fixed completed backup regardless of the arguments.
func (f *fakeLonghornClient) FindBackup(_ context.Context, volumeName, snapshot string) (*longhorn.Backup, error) {
	return &longhorn.Backup{Name: "backup-latest", URL: "s3://bucket/backup-latest", State: "Completed"}, nil
}

// ListBackups returns the configured backups slice for any volume.
func (f *fakeLonghornClient) ListBackups(_ context.Context, volumeName string) ([]longhorn.Backup, error) {
	return f.backups, nil
}

// newRoutedTestServer wires a Server to cfg and the given fakes, with a fresh
// telemetry instance and srv.route installed as the inner handler.
func newRoutedTestServer(cfg *config.Config, kube *fakeKubeClient, lh *fakeLonghornClient) *Server {
	srv := &Server{
		cfg:      cfg,
		client:   kube,
		longhorn: lh,
		metrics:  newTelemetry(),
	}
	srv.handler = http.HandlerFunc(srv.route)
	return srv
}

// serveTestRequest sends req through srv.Handler() and returns the recorded response.
func serveTestRequest(srv *Server, req *http.Request) *httptest.ResponseRecorder {
	res := httptest.NewRecorder()
	srv.Handler().ServeHTTP(res, req)
	return res
}

// newInventoryTestFakes returns fakes describing one bound PVC (apps/data)
// that has a single completed Longhorn backup.
func newInventoryTestFakes() (*fakeKubeClient, *fakeLonghornClient) {
	kube := &fakeKubeClient{pvcs: []k8s.PVCSummary{
		{Namespace: "apps", Name: "data", VolumeName: "pv-apps-data", Phase: "Bound"},
	}}
	lh := &fakeLonghornClient{backups: []longhorn.Backup{
		{Name: "backup-1", Created: "2026-04-12T00:00:00Z", State: "Completed", URL: "s3://bucket/backup-1"},
	}}
	return kube, lh
}

// assertAppsInventory decodes res as an api.InventoryResponse and fails the
// test unless it lists exactly one namespace named "apps", matching
// newInventoryTestFakes.
func assertAppsInventory(t *testing.T, res *httptest.ResponseRecorder) {
	t.Helper()
	var payload api.InventoryResponse
	// Unmarshal from the recorder's bytes directly; this avoids copying the
	// body into a string and leaves the buffer unread for error messages.
	if err := json.Unmarshal(res.Body.Bytes(), &payload); err != nil {
		t.Fatalf("decode inventory: %v", err)
	}
	if len(payload.Namespaces) != 1 || payload.Namespaces[0].Name != "apps" {
		t.Fatalf("unexpected inventory payload: %#v", payload)
	}
}

// TestProtectedInventoryRequiresAuth checks that /v1/inventory answers 401
// when AuthRequired is set and the request carries no identity headers.
func TestProtectedInventoryRequiresAuth(t *testing.T) {
	srv := newRoutedTestServer(
		&config.Config{AuthRequired: true, AllowedGroups: []string{"admin", "maintenance"}, BackupDriver: "longhorn"},
		&fakeKubeClient{},
		&fakeLonghornClient{},
	)
	req := httptest.NewRequest(http.MethodGet, "/v1/inventory", nil)

	res := serveTestRequest(srv, req)
	if res.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d: %s", res.Code, res.Body.String())
	}
}

// TestProtectedInventoryAllowsMaintenanceGroup checks that X-Auth-Request-*
// identity headers are accepted. The group value "/maintenance" (leading
// slash) is expected to match the allow-list entry "maintenance".
func TestProtectedInventoryAllowsMaintenanceGroup(t *testing.T) {
	kube, lh := newInventoryTestFakes()
	srv := newRoutedTestServer(
		&config.Config{AuthRequired: true, AllowedGroups: []string{"admin", "maintenance"}, BackupDriver: "longhorn", BackupMaxAge: 24 * time.Hour},
		kube,
		lh,
	)
	req := httptest.NewRequest(http.MethodGet, "/v1/inventory", nil)
	req.Header.Set("X-Auth-Request-User", "brad")
	req.Header.Set("X-Auth-Request-Groups", "/maintenance")

	res := serveTestRequest(srv, req)
	if res.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", res.Code, res.Body.String())
	}
	assertAppsInventory(t, res)
}

// TestProtectedInventoryAllowsForwardedHeaders checks that X-Forwarded-*
// identity headers are accepted. The group header is a ";"-separated list,
// and one of its entries must match the allow-list.
func TestProtectedInventoryAllowsForwardedHeaders(t *testing.T) {
	kube, lh := newInventoryTestFakes()
	srv := newRoutedTestServer(
		&config.Config{AuthRequired: true, AllowedGroups: []string{"admin", "maintenance"}, BackupDriver: "longhorn", BackupMaxAge: 24 * time.Hour},
		kube,
		lh,
	)
	req := httptest.NewRequest(http.MethodGet, "/v1/inventory", nil)
	req.Header.Set("X-Forwarded-User", "brad")
	req.Header.Set("X-Forwarded-Groups", "/ops;/maintenance")

	res := serveTestRequest(srv, req)
	if res.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", res.Code, res.Body.String())
	}
	// Same fixtures as the X-Auth-Request test, so the payload must match too.
	assertAppsInventory(t, res)
}

// TestRestoreRejectsExistingTargetPVC checks that POST /v1/restores answers
// 409 when the fake reports that the target PVC already exists.
func TestRestoreRejectsExistingTargetPVC(t *testing.T) {
	srv := newRoutedTestServer(
		&config.Config{AuthRequired: false, BackupDriver: "longhorn"},
		&fakeKubeClient{targetExists: true},
		&fakeLonghornClient{},
	)
	body := `{"namespace":"apps","pvc":"data","target_namespace":"apps","target_pvc":"restore-data"}`
	req := httptest.NewRequest(http.MethodPost, "/v1/restores", strings.NewReader(body))

	res := serveTestRequest(srv, req)
	if res.Code != http.StatusConflict {
		t.Fatalf("expected 409, got %d: %s", res.Code, res.Body.String())
	}
}

// TestMetricsStayPublic checks that /metrics is served without identity
// headers even when AuthRequired is set, and that it exposes the
// soteria_backup_requests_total series.
func TestMetricsStayPublic(t *testing.T) {
	srv := newRoutedTestServer(
		&config.Config{AuthRequired: true, AllowedGroups: []string{"admin"}},
		&fakeKubeClient{},
		&fakeLonghornClient{},
	)
	req := httptest.NewRequest(http.MethodGet, "/metrics", nil)

	res := serveTestRequest(srv, req)
	if res.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", res.Code, res.Body.String())
	}
	if !strings.Contains(res.Body.String(), "soteria_backup_requests_total") {
		t.Fatalf("expected prometheus metrics body, got %q", res.Body.String())
	}
}