fix: route portal access to portal db
This commit is contained in:
parent
a228e063f1
commit
e72beb89bd
@ -62,7 +62,7 @@ class PasswordResetRequest:
|
|||||||
|
|
||||||
portal_db = Database(settings.portal_database_url)
|
portal_db = Database(settings.portal_database_url)
|
||||||
ariadne_db = Database(settings.ariadne_database_url)
|
ariadne_db = Database(settings.ariadne_database_url)
|
||||||
storage = Storage(ariadne_db)
|
storage = Storage(ariadne_db, portal_db)
|
||||||
provisioning = ProvisioningManager(portal_db, storage)
|
provisioning = ProvisioningManager(portal_db, storage)
|
||||||
scheduler = CronScheduler(storage, settings.schedule_tick_sec)
|
scheduler = CronScheduler(storage, settings.schedule_tick_sec)
|
||||||
|
|
||||||
|
|||||||
@ -60,8 +60,9 @@ class ScheduleState:
|
|||||||
|
|
||||||
|
|
||||||
class Storage:
|
class Storage:
|
||||||
def __init__(self, db: Database) -> None:
|
def __init__(self, db: Database, portal_db: Database | None = None) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
|
self._portal_db = portal_db or db
|
||||||
|
|
||||||
def ensure_task_rows(self, request_code: str, tasks: Iterable[str]) -> None:
|
def ensure_task_rows(self, request_code: str, tasks: Iterable[str]) -> None:
|
||||||
tasks_list = list(tasks)
|
tasks_list = list(tasks)
|
||||||
@ -109,7 +110,7 @@ class Storage:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def fetch_access_request(self, request_code: str) -> AccessRequest | None:
|
def fetch_access_request(self, request_code: str) -> AccessRequest | None:
|
||||||
row = self._db.fetchone(
|
row = self._portal_db.fetchone(
|
||||||
"""
|
"""
|
||||||
SELECT request_code, username, contact_email, status, email_verified_at,
|
SELECT request_code, username, contact_email, status, email_verified_at,
|
||||||
initial_password, initial_password_revealed_at, provision_attempted_at,
|
initial_password, initial_password_revealed_at, provision_attempted_at,
|
||||||
@ -124,7 +125,7 @@ class Storage:
|
|||||||
return self._row_to_request(row)
|
return self._row_to_request(row)
|
||||||
|
|
||||||
def find_access_request_by_username(self, username: str) -> AccessRequest | None:
|
def find_access_request_by_username(self, username: str) -> AccessRequest | None:
|
||||||
row = self._db.fetchone(
|
row = self._portal_db.fetchone(
|
||||||
"""
|
"""
|
||||||
SELECT request_code, username, contact_email, status, email_verified_at,
|
SELECT request_code, username, contact_email, status, email_verified_at,
|
||||||
initial_password, initial_password_revealed_at, provision_attempted_at,
|
initial_password, initial_password_revealed_at, provision_attempted_at,
|
||||||
@ -141,7 +142,7 @@ class Storage:
|
|||||||
return self._row_to_request(row)
|
return self._row_to_request(row)
|
||||||
|
|
||||||
def list_pending_requests(self) -> list[dict[str, Any]]:
|
def list_pending_requests(self) -> list[dict[str, Any]]:
|
||||||
return self._db.fetchall(
|
return self._portal_db.fetchall(
|
||||||
"""
|
"""
|
||||||
SELECT request_code, username, contact_email, note, status, created_at
|
SELECT request_code, username, contact_email, note, status, created_at
|
||||||
FROM access_requests
|
FROM access_requests
|
||||||
@ -152,7 +153,7 @@ class Storage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def list_provision_candidates(self) -> list[AccessRequest]:
|
def list_provision_candidates(self) -> list[AccessRequest]:
|
||||||
rows = self._db.fetchall(
|
rows = self._portal_db.fetchall(
|
||||||
"""
|
"""
|
||||||
SELECT request_code, username, contact_email, status, email_verified_at,
|
SELECT request_code, username, contact_email, status, email_verified_at,
|
||||||
initial_password, initial_password_revealed_at, provision_attempted_at,
|
initial_password, initial_password_revealed_at, provision_attempted_at,
|
||||||
@ -166,19 +167,19 @@ class Storage:
|
|||||||
return [self._row_to_request(row) for row in rows]
|
return [self._row_to_request(row) for row in rows]
|
||||||
|
|
||||||
def update_status(self, request_code: str, status: str) -> None:
|
def update_status(self, request_code: str, status: str) -> None:
|
||||||
self._db.execute(
|
self._portal_db.execute(
|
||||||
"UPDATE access_requests SET status = %s WHERE request_code = %s",
|
"UPDATE access_requests SET status = %s WHERE request_code = %s",
|
||||||
(status, request_code),
|
(status, request_code),
|
||||||
)
|
)
|
||||||
|
|
||||||
def mark_provision_attempted(self, request_code: str) -> None:
|
def mark_provision_attempted(self, request_code: str) -> None:
|
||||||
self._db.execute(
|
self._portal_db.execute(
|
||||||
"UPDATE access_requests SET provision_attempted_at = NOW() WHERE request_code = %s",
|
"UPDATE access_requests SET provision_attempted_at = NOW() WHERE request_code = %s",
|
||||||
(request_code,),
|
(request_code,),
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_initial_password(self, request_code: str, password: str) -> None:
|
def set_initial_password(self, request_code: str, password: str) -> None:
|
||||||
self._db.execute(
|
self._portal_db.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE access_requests
|
UPDATE access_requests
|
||||||
SET initial_password = %s
|
SET initial_password = %s
|
||||||
@ -188,7 +189,7 @@ class Storage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def mark_welcome_sent(self, request_code: str) -> None:
|
def mark_welcome_sent(self, request_code: str) -> None:
|
||||||
self._db.execute(
|
self._portal_db.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE access_requests
|
UPDATE access_requests
|
||||||
SET welcome_email_sent_at = NOW()
|
SET welcome_email_sent_at = NOW()
|
||||||
@ -198,7 +199,7 @@ class Storage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update_approval(self, request_code: str, status: str, decided_by: str, flags: list[str], note: str | None) -> None:
|
def update_approval(self, request_code: str, status: str, decided_by: str, flags: list[str], note: str | None) -> None:
|
||||||
self._db.execute(
|
self._portal_db.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE access_requests
|
UPDATE access_requests
|
||||||
SET status = %s,
|
SET status = %s,
|
||||||
|
|||||||
@ -95,6 +95,10 @@ def _domain_matches(email: str) -> bool:
|
|||||||
return email.lower().endswith(f"@{settings.mailu_domain.lower()}")
|
return email.lower().endswith(f"@{settings.mailu_domain.lower()}")
|
||||||
|
|
||||||
|
|
||||||
|
def _password_too_long(password: str) -> bool:
|
||||||
|
return len(password.encode("utf-8")) > 72
|
||||||
|
|
||||||
|
|
||||||
class MailuService:
|
class MailuService:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._db_config = {
|
self._db_config = {
|
||||||
@ -163,6 +167,7 @@ class MailuService:
|
|||||||
|
|
||||||
def _prepare_updates(
|
def _prepare_updates(
|
||||||
self,
|
self,
|
||||||
|
username: str,
|
||||||
attrs: dict[str, Any],
|
attrs: dict[str, Any],
|
||||||
mailu_email: str,
|
mailu_email: str,
|
||||||
) -> tuple[bool, dict[str, list[str]], str]:
|
) -> tuple[bool, dict[str, list[str]], str]:
|
||||||
@ -176,6 +181,18 @@ class MailuService:
|
|||||||
if not app_password:
|
if not app_password:
|
||||||
app_password = random_password(24)
|
app_password = random_password(24)
|
||||||
updates[MAILU_APP_PASSWORD_ATTR] = [app_password]
|
updates[MAILU_APP_PASSWORD_ATTR] = [app_password]
|
||||||
|
elif _password_too_long(app_password):
|
||||||
|
app_password = random_password(24)
|
||||||
|
updates[MAILU_APP_PASSWORD_ATTR] = [app_password]
|
||||||
|
logger.info(
|
||||||
|
"mailu app password rotated",
|
||||||
|
extra={
|
||||||
|
"event": "mailu_sync",
|
||||||
|
"status": "updated",
|
||||||
|
"detail": "app password exceeded bcrypt limit",
|
||||||
|
"username": username,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return enabled, updates, app_password
|
return enabled, updates, app_password
|
||||||
|
|
||||||
@ -212,7 +229,7 @@ class MailuService:
|
|||||||
attrs,
|
attrs,
|
||||||
user.get("email") if isinstance(user.get("email"), str) else "",
|
user.get("email") if isinstance(user.get("email"), str) else "",
|
||||||
)
|
)
|
||||||
enabled, updates, app_password = self._prepare_updates(attrs, mailu_email)
|
enabled, updates, app_password = self._prepare_updates(username, attrs, mailu_email)
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return MailuUserSyncResult(skipped=1)
|
return MailuUserSyncResult(skipped=1)
|
||||||
@ -247,6 +264,8 @@ class MailuService:
|
|||||||
return False
|
return False
|
||||||
if not _domain_matches(email):
|
if not _domain_matches(email):
|
||||||
return False
|
return False
|
||||||
|
if _password_too_long(password):
|
||||||
|
raise ValueError("mailu password exceeds bcrypt limit")
|
||||||
|
|
||||||
localpart, domain = email.split("@", 1)
|
localpart, domain = email.split("@", 1)
|
||||||
hashed = bcrypt_sha256.hash(password)
|
hashed = bcrypt_sha256.hash(password)
|
||||||
@ -298,6 +317,12 @@ class MailuService:
|
|||||||
extra={"event": "mailu_sync", "status": "error", "detail": "system password missing"},
|
extra={"event": "mailu_sync", "status": "error", "detail": "system password missing"},
|
||||||
)
|
)
|
||||||
return 0
|
return 0
|
||||||
|
if _password_too_long(settings.mailu_system_password):
|
||||||
|
logger.info(
|
||||||
|
"mailu system password too long",
|
||||||
|
extra={"event": "mailu_sync", "status": "error", "detail": "system password exceeds bcrypt limit"},
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
ensured = 0
|
ensured = 0
|
||||||
for email in settings.mailu_system_users:
|
for email in settings.mailu_system_users:
|
||||||
|
|||||||
@ -389,6 +389,68 @@ def test_mailu_sync_updates_attrs(monkeypatch) -> None:
|
|||||||
assert "mailu_email" in updates[0][1]["attributes"]
|
assert "mailu_email" in updates[0][1]["attributes"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mailu_sync_rotates_long_password(monkeypatch) -> None:
|
||||||
|
long_password = "x" * 100
|
||||||
|
dummy_settings = types.SimpleNamespace(
|
||||||
|
mailu_domain="bstein.dev",
|
||||||
|
mailu_db_host="localhost",
|
||||||
|
mailu_db_port=5432,
|
||||||
|
mailu_db_name="mailu",
|
||||||
|
mailu_db_user="mailu",
|
||||||
|
mailu_db_password="secret",
|
||||||
|
mailu_default_quota=20000000000,
|
||||||
|
mailu_system_users=[],
|
||||||
|
mailu_system_password="",
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("ariadne.services.mailu.settings", dummy_settings)
|
||||||
|
monkeypatch.setattr("ariadne.services.mailu.keycloak_admin.ready", lambda: True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ariadne.services.mailu.keycloak_admin.iter_users",
|
||||||
|
lambda *args, **kwargs: [
|
||||||
|
{
|
||||||
|
"id": "1",
|
||||||
|
"username": "alice",
|
||||||
|
"enabled": True,
|
||||||
|
"email": "alice@example.com",
|
||||||
|
"attributes": {"mailu_app_password": [long_password]},
|
||||||
|
"firstName": "Alice",
|
||||||
|
"lastName": "Example",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("ariadne.services.mailu.random_password", lambda *_args, **_kwargs: "short-pass-123")
|
||||||
|
|
||||||
|
updates: list[tuple[str, dict[str, object]]] = []
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ariadne.services.mailu.keycloak_admin.update_user_safe",
|
||||||
|
lambda user_id, payload: updates.append((user_id, payload)),
|
||||||
|
)
|
||||||
|
|
||||||
|
mailbox_calls: list[tuple[str, str, str]] = []
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ariadne.services.mailu.MailuService._ensure_mailbox",
|
||||||
|
lambda self, _conn, email, password, display: mailbox_calls.append((email, password, display)) or True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyConn:
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr("ariadne.services.mailu.psycopg.connect", lambda *args, **kwargs: DummyConn())
|
||||||
|
|
||||||
|
svc = MailuService()
|
||||||
|
summary = svc.sync("provision", force=True)
|
||||||
|
|
||||||
|
assert summary.processed == 1
|
||||||
|
assert updates
|
||||||
|
attrs = updates[0][1]["attributes"]
|
||||||
|
assert attrs["mailu_app_password"] == ["short-pass-123"]
|
||||||
|
assert mailbox_calls
|
||||||
|
assert mailbox_calls[0][1] == "short-pass-123"
|
||||||
|
|
||||||
def test_mailu_sync_skips_disabled(monkeypatch) -> None:
|
def test_mailu_sync_skips_disabled(monkeypatch) -> None:
|
||||||
dummy_settings = types.SimpleNamespace(
|
dummy_settings = types.SimpleNamespace(
|
||||||
mailu_domain="bstein.dev",
|
mailu_domain="bstein.dev",
|
||||||
|
|||||||
@ -53,6 +53,33 @@ def test_row_to_request_flags() -> None:
|
|||||||
assert req.approval_flags == ["demo", "1", "test"]
|
assert req.approval_flags == ["demo", "1", "test"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_access_requests_use_portal_db() -> None:
|
||||||
|
portal_row = {
|
||||||
|
"request_code": "req",
|
||||||
|
"username": "alice",
|
||||||
|
"contact_email": "a@example.com",
|
||||||
|
"status": "pending",
|
||||||
|
"email_verified_at": None,
|
||||||
|
"initial_password": None,
|
||||||
|
"initial_password_revealed_at": None,
|
||||||
|
"provision_attempted_at": None,
|
||||||
|
"approval_flags": [],
|
||||||
|
"approval_note": None,
|
||||||
|
"denial_note": None,
|
||||||
|
}
|
||||||
|
db = DummyDB()
|
||||||
|
portal = DummyDB(row=portal_row)
|
||||||
|
portal.rows = [{"request_code": "req"}]
|
||||||
|
storage = Storage(db, portal)
|
||||||
|
|
||||||
|
rows = storage.list_pending_requests()
|
||||||
|
assert rows == portal.rows
|
||||||
|
|
||||||
|
req = storage.fetch_access_request("req")
|
||||||
|
assert req is not None
|
||||||
|
assert req.request_code == "req"
|
||||||
|
|
||||||
|
|
||||||
def test_record_event_serializes_dict() -> None:
|
def test_record_event_serializes_dict() -> None:
|
||||||
db = DummyDB()
|
db = DummyDB()
|
||||||
storage = Storage(db)
|
storage = Storage(db)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user