comms: harden guest register provisioning

This commit is contained in:
Brad Stein 2026-01-17 16:51:40 -03:00
parent f15b80872e
commit 931e41a76f

View File

@ -3,6 +3,7 @@ import json
import os import os
import random import random
import secrets import secrets
import time
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib import error, parse, request from urllib import error, parse, request
@ -29,6 +30,20 @@ NOUN = [
"pine","quartz","reef","ridge","sable","sage","shore","thunder","vale","zephyr", "pine","quartz","reef","ridge","sable","sage","shore","thunder","vale","zephyr",
] ]
def _open_with_retry(req, timeout, attempts=6):
last = None
for attempt in range(1, attempts + 1):
try:
return request.urlopen(req, timeout=timeout)
except error.HTTPError as e:
return e
except (error.URLError, TimeoutError, OSError) as e:
last = e
time.sleep(attempt * 2)
if last:
raise last
raise RuntimeError("request_failed")
def _json(method, url, *, headers=None, body=None, timeout=20): def _json(method, url, *, headers=None, body=None, timeout=20):
hdrs = {"Content-Type": "application/json"} hdrs = {"Content-Type": "application/json"}
if headers: if headers:
@ -37,18 +52,17 @@ def _json(method, url, *, headers=None, body=None, timeout=20):
if body is not None: if body is not None:
data = json.dumps(body).encode() data = json.dumps(body).encode()
req = request.Request(url, data=data, headers=hdrs, method=method) req = request.Request(url, data=data, headers=hdrs, method=method)
try: resp = _open_with_retry(req, timeout)
with request.urlopen(req, timeout=timeout) as resp: if isinstance(resp, error.HTTPError):
raw = resp.read() raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
except error.HTTPError as e:
raw = e.read()
try: try:
payload = json.loads(raw.decode()) if raw else {} payload = json.loads(raw.decode()) if raw else {}
except Exception: except Exception:
payload = {} payload = {}
return e.code, payload return resp.code, payload
raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
def _form(method, url, *, headers=None, fields=None, timeout=20): def _form(method, url, *, headers=None, fields=None, timeout=20):
hdrs = {"Content-Type": "application/x-www-form-urlencoded"} hdrs = {"Content-Type": "application/x-www-form-urlencoded"}
@ -56,18 +70,17 @@ def _form(method, url, *, headers=None, fields=None, timeout=20):
hdrs.update(headers) hdrs.update(headers)
data = parse.urlencode(fields or {}).encode() data = parse.urlencode(fields or {}).encode()
req = request.Request(url, data=data, headers=hdrs, method=method) req = request.Request(url, data=data, headers=hdrs, method=method)
try: resp = _open_with_retry(req, timeout)
with request.urlopen(req, timeout=timeout) as resp: if isinstance(resp, error.HTTPError):
raw = resp.read() raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
except error.HTTPError as e:
raw = e.read()
try: try:
payload = json.loads(raw.decode()) if raw else {} payload = json.loads(raw.decode()) if raw else {}
except Exception: except Exception:
payload = {} payload = {}
return e.code, payload return resp.code, payload
raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
_admin_token = None _admin_token = None
_admin_token_at = 0.0 _admin_token_at = 0.0
@ -110,12 +123,28 @@ def _admin_api(admin_token, method, path, body=None):
timeout=20, timeout=20,
) )
def _create_user(admin_token, username): def _create_user(admin_token, username, password):
status, payload = _admin_api(admin_token, "POST", "/users", {"username": username}) payloads = [
if status != 201: {
return status, None "data": {
user = payload.get("data") or {} "type": "user",
return status, user.get("id") "attributes": {
"username": username,
"password": password,
},
}
},
{"username": username, "password": password},
{"username": username},
]
for payload in payloads:
status, body = _admin_api(admin_token, "POST", "/users", payload)
if status in (200, 201):
user = body.get("data") or {}
return status, user.get("id") or user.get("id")
if status == 409:
return status, None
return status, None
def _set_password(admin_token, user_id, password): def _set_password(admin_token, user_id, password):
status, _payload = _admin_api( status, _payload = _admin_api(
@ -127,20 +156,28 @@ def _set_password(admin_token, user_id, password):
return status in (200, 204) return status in (200, 204)
def _login_password(username, password): def _login_password(username, password):
payload = { payloads = [
"type": "m.login.password", {
"identifier": {"type": "m.id.user", "user": f"@{username}:{SERVER_NAME}"}, "type": "m.login.password",
"password": password, "identifier": {"type": "m.id.user", "user": f"@{username}:{SERVER_NAME}"},
} "password": password,
status, data = _json( },
"POST", {
f"{MAS_BASE}/_matrix/client/v3/login", "type": "m.login.password",
body=payload, "identifier": {"type": "m.id.user", "user": username},
timeout=20, "password": password,
) },
if status != 200: ]
return None, None for payload in payloads:
return data.get("access_token"), data.get("device_id") status, data = _json(
"POST",
f"{MAS_BASE}/_matrix/client/v3/login",
body=payload,
timeout=20,
)
if status == 200:
return data.get("access_token"), data.get("device_id")
return None, None
def _set_display_name(access_token, user_id, displayname): def _set_display_name(access_token, user_id, displayname):
_json( _json(
@ -224,18 +261,18 @@ class Handler(BaseHTTPRequestHandler):
admin_token = _mas_admin_access_token(now) admin_token = _mas_admin_access_token(now)
displayname = _generate_displayname() displayname = _generate_displayname()
password = secrets.token_urlsafe(18)
localpart = None localpart = None
mas_user_id = None mas_user_id = None
for _ in range(5): for _ in range(5):
localpart = _generate_localpart() localpart = _generate_localpart()
status, mas_user_id = _create_user(admin_token, localpart) status, mas_user_id = _create_user(admin_token, localpart, password)
if status == 201 and mas_user_id: if status in (200, 201) and mas_user_id:
break break
mas_user_id = None mas_user_id = None
if not mas_user_id or not localpart: if not mas_user_id or not localpart:
raise RuntimeError("add_user_failed") raise RuntimeError("add_user_failed")
password = secrets.token_urlsafe(18)
if not _set_password(admin_token, mas_user_id, password): if not _set_password(admin_token, mas_user_id, password):
raise RuntimeError("set_password_failed") raise RuntimeError("set_password_failed")
access_token, device_id = _login_password(localpart, password) access_token, device_id = _login_password(localpart, password)