#!/usr/bin/env python3 from __future__ import annotations import argparse import sys from collections import defaultdict from dataclasses import dataclass from typing import Any, Iterable from urllib.parse import quote import httpx from atlas_portal import db, settings from atlas_portal.keycloak import admin_client @dataclass(frozen=True) class KeycloakUser: id: str username: str @dataclass(frozen=True) class PortalRequest: request_code: str username: str status: str def _dedupe_by_id(users: Iterable[KeycloakUser]) -> list[KeycloakUser]: seen: set[str] = set() out: list[KeycloakUser] = [] for user in users: if user.id in seen: continue seen.add(user.id) out.append(user) return out def _iter_keycloak_users_for_prefix(prefix: str, max_results: int) -> list[KeycloakUser]: client = admin_client() if not client.ready(): raise RuntimeError("keycloak admin client not configured in this environment") url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users" # Keycloak can return false positives for search; we do a strict prefix match client-side. params = {"search": prefix, "max": str(max_results), "briefRepresentation": "true"} with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as http: resp = http.get(url, params=params, headers=client.headers()) resp.raise_for_status() payload = resp.json() if not isinstance(payload, list): return [] found: list[KeycloakUser] = [] for item in payload: if not isinstance(item, dict): continue username = item.get("username") user_id = item.get("id") if not isinstance(username, str) or not isinstance(user_id, str): continue if not username.startswith(prefix): continue if username.startswith("service-account-"): continue found.append(KeycloakUser(id=user_id, username=username)) return found def _find_keycloak_users(prefixes: list[str], max_results: int, protected: set[str]) -> list[KeycloakUser]: matches: list[KeycloakUser] = [] for prefix in prefixes: matches.extend(_iter_keycloak_users_for_prefix(prefix, max_results=max_results)) deduped = _dedupe_by_id(matches) return [user for user in deduped if user.username not in protected] def _delete_keycloak_users(users: list[KeycloakUser]) -> None: if not users: return client = admin_client() if not client.ready(): raise RuntimeError("keycloak admin client not configured in this environment") base = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users" with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as http: for user in users: url = f"{base}/{quote(user.id, safe='')}" resp = http.delete(url, headers=client.headers()) # Deleting a non-existent user is treated as success for idempotency. if resp.status_code == 404: continue resp.raise_for_status() def _find_portal_requests(prefixes: list[str], max_results: int) -> list[PortalRequest]: if not db.configured(): return [] like_prefixes = [f"{prefix}%" for prefix in prefixes] rows: list[dict[str, Any]] = [] with db.connect() as conn: for like in like_prefixes: cursor = conn.execute( """ SELECT request_code, username, status FROM access_requests WHERE username LIKE %s ORDER BY created_at DESC LIMIT %s """, (like, max_results), ) batch = cursor.fetchall() if isinstance(batch, list): rows.extend([r for r in batch if isinstance(r, dict)]) out: list[PortalRequest] = [] for row in rows: request_code = row.get("request_code") username = row.get("username") status = row.get("status") if not isinstance(request_code, str) or not isinstance(username, str) or not isinstance(status, str): continue out.append(PortalRequest(request_code=request_code, username=username, status=status)) return out def _delete_portal_requests(prefixes: list[str]) -> int: if not db.configured(): return 0 like_prefixes = [f"{prefix}%" for prefix in prefixes] deleted = 0 with db.connect() as conn: for like in like_prefixes: cursor = conn.execute("DELETE FROM access_requests WHERE username LIKE %s", (like,)) deleted += cursor.rowcount or 0 return deleted def _summarize_portal_requests(rows: list[PortalRequest]) -> dict[str, int]: counts: dict[str, int] = defaultdict(int) for row in rows: counts[row.status] += 1 return dict(counts) def _parse_args(argv: list[str]) -> argparse.Namespace: parser = argparse.ArgumentParser( prog="test_user_cleanup", description=( "Manual-only cleanup for test users/requests. " "This script is intended to be run inside the bstein-dev-home backend container." ), ) parser.add_argument( "--prefix", action="append", required=True, help="Username prefix to target (repeatable). Example: --prefix test-", ) parser.add_argument( "--max", type=int, default=500, help="Maximum users/requests to enumerate per prefix (default: 500).", ) parser.add_argument( "--apply", action="store_true", help="Apply deletions (default is dry-run). Requires --confirm.", ) parser.add_argument( "--confirm", default="", help="Required when using --apply. Must exactly equal the comma-separated prefix list.", ) parser.add_argument( "--skip-keycloak", action="store_true", help="Skip deleting Keycloak users.", ) parser.add_argument( "--skip-portal", action="store_true", help="Skip deleting portal (DB) access requests.", ) parser.add_argument( "--protect", action="append", default=[], help="Extra usernames to never delete (repeatable).", ) parser.add_argument( "--verbose", action="store_true", help="List matched usernames/request codes.", ) return parser.parse_args(argv) def main(argv: list[str]) -> int: args = _parse_args(argv) prefixes = sorted({p.strip() for p in args.prefix if p.strip()}) if not prefixes: print("error: no valid --prefix values provided", file=sys.stderr) return 2 expected_confirm = ",".join(prefixes) protected = {"bstein", "robotuser", *[p.strip() for p in args.protect if p.strip()]} if args.apply and args.confirm != expected_confirm: print( f"error: refusing to apply without --confirm '{expected_confirm}' (got '{args.confirm}')", file=sys.stderr, ) return 2 keycloak_users: list[KeycloakUser] = [] portal_requests: list[PortalRequest] = [] if not args.skip_keycloak: keycloak_users = _find_keycloak_users(prefixes, max_results=args.max, protected=protected) if not args.skip_portal: portal_requests = _find_portal_requests(prefixes, max_results=args.max) print(f"prefixes: {expected_confirm}") print(f"mode: {'APPLY' if args.apply else 'DRY-RUN'}") if protected: print(f"protected usernames: {', '.join(sorted(protected))}") if not args.skip_keycloak: print(f"keycloak users matched: {len(keycloak_users)}") if args.verbose and keycloak_users: for user in sorted(keycloak_users, key=lambda u: u.username): print(f" - {user.username}") if not args.skip_portal: print(f"portal requests matched: {len(portal_requests)}") if portal_requests: summary = _summarize_portal_requests(portal_requests) summary_str = ", ".join(f"{k}={v}" for k, v in sorted(summary.items())) print(f" statuses: {summary_str}") if args.verbose and portal_requests: for req in portal_requests[: min(50, len(portal_requests))]: print(f" - {req.request_code} ({req.status})") if len(portal_requests) > 50: print(f" ... and {len(portal_requests) - 50} more") if not args.apply: print("dry-run complete (no changes made)") return 0 if not args.skip_portal: deleted = _delete_portal_requests(prefixes) print(f"deleted portal requests: {deleted}") if not args.skip_keycloak: _delete_keycloak_users(keycloak_users) print(f"deleted keycloak users: {len(keycloak_users)}") print("done") return 0 if __name__ == "__main__": raise SystemExit(main(sys.argv[1:]))