comms: harden guest register provisioning

This commit is contained in:
Brad Stein 2026-01-17 16:51:40 -03:00
parent f15b80872e
commit 931e41a76f

View File

@ -3,6 +3,7 @@ import json
import os
import random
import secrets
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib import error, parse, request
@ -29,6 +30,20 @@ NOUN = [
"pine","quartz","reef","ridge","sable","sage","shore","thunder","vale","zephyr",
]
def _open_with_retry(req, timeout, attempts=6):
last = None
for attempt in range(1, attempts + 1):
try:
return request.urlopen(req, timeout=timeout)
except error.HTTPError as e:
return e
except (error.URLError, TimeoutError, OSError) as e:
last = e
time.sleep(attempt * 2)
if last:
raise last
raise RuntimeError("request_failed")
def _json(method, url, *, headers=None, body=None, timeout=20):
hdrs = {"Content-Type": "application/json"}
if headers:
@ -37,18 +52,17 @@ def _json(method, url, *, headers=None, body=None, timeout=20):
if body is not None:
data = json.dumps(body).encode()
req = request.Request(url, data=data, headers=hdrs, method=method)
try:
with request.urlopen(req, timeout=timeout) as resp:
raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
except error.HTTPError as e:
raw = e.read()
resp = _open_with_retry(req, timeout)
if isinstance(resp, error.HTTPError):
raw = resp.read()
try:
payload = json.loads(raw.decode()) if raw else {}
except Exception:
payload = {}
return e.code, payload
return resp.code, payload
raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
def _form(method, url, *, headers=None, fields=None, timeout=20):
hdrs = {"Content-Type": "application/x-www-form-urlencoded"}
@ -56,18 +70,17 @@ def _form(method, url, *, headers=None, fields=None, timeout=20):
hdrs.update(headers)
data = parse.urlencode(fields or {}).encode()
req = request.Request(url, data=data, headers=hdrs, method=method)
try:
with request.urlopen(req, timeout=timeout) as resp:
raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
except error.HTTPError as e:
raw = e.read()
resp = _open_with_retry(req, timeout)
if isinstance(resp, error.HTTPError):
raw = resp.read()
try:
payload = json.loads(raw.decode()) if raw else {}
except Exception:
payload = {}
return e.code, payload
return resp.code, payload
raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
_admin_token = None
_admin_token_at = 0.0
@ -110,12 +123,28 @@ def _admin_api(admin_token, method, path, body=None):
timeout=20,
)
def _create_user(admin_token, username):
status, payload = _admin_api(admin_token, "POST", "/users", {"username": username})
if status != 201:
return status, None
user = payload.get("data") or {}
return status, user.get("id")
def _create_user(admin_token, username, password):
payloads = [
{
"data": {
"type": "user",
"attributes": {
"username": username,
"password": password,
},
}
},
{"username": username, "password": password},
{"username": username},
]
for payload in payloads:
status, body = _admin_api(admin_token, "POST", "/users", payload)
if status in (200, 201):
user = body.get("data") or {}
return status, user.get("id") or user.get("id")
if status == 409:
return status, None
return status, None
def _set_password(admin_token, user_id, password):
status, _payload = _admin_api(
@ -127,20 +156,28 @@ def _set_password(admin_token, user_id, password):
return status in (200, 204)
def _login_password(username, password):
payload = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": f"@{username}:{SERVER_NAME}"},
"password": password,
}
status, data = _json(
"POST",
f"{MAS_BASE}/_matrix/client/v3/login",
body=payload,
timeout=20,
)
if status != 200:
return None, None
return data.get("access_token"), data.get("device_id")
payloads = [
{
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": f"@{username}:{SERVER_NAME}"},
"password": password,
},
{
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": username},
"password": password,
},
]
for payload in payloads:
status, data = _json(
"POST",
f"{MAS_BASE}/_matrix/client/v3/login",
body=payload,
timeout=20,
)
if status == 200:
return data.get("access_token"), data.get("device_id")
return None, None
def _set_display_name(access_token, user_id, displayname):
_json(
@ -224,18 +261,18 @@ class Handler(BaseHTTPRequestHandler):
admin_token = _mas_admin_access_token(now)
displayname = _generate_displayname()
password = secrets.token_urlsafe(18)
localpart = None
mas_user_id = None
for _ in range(5):
localpart = _generate_localpart()
status, mas_user_id = _create_user(admin_token, localpart)
if status == 201 and mas_user_id:
status, mas_user_id = _create_user(admin_token, localpart, password)
if status in (200, 201) and mas_user_id:
break
mas_user_id = None
if not mas_user_id or not localpart:
raise RuntimeError("add_user_failed")
password = secrets.token_urlsafe(18)
if not _set_password(admin_token, mas_user_id, password):
raise RuntimeError("set_password_failed")
access_token, device_id = _login_password(localpart, password)