diff --git a/services/comms/scripts/guest-register/server.py b/services/comms/scripts/guest-register/server.py index 0e1fb4c..9daa971 100644 --- a/services/comms/scripts/guest-register/server.py +++ b/services/comms/scripts/guest-register/server.py @@ -3,6 +3,7 @@ import json import os import random import secrets +import time from http.server import BaseHTTPRequestHandler, HTTPServer from urllib import error, parse, request @@ -29,6 +30,20 @@ NOUN = [ "pine","quartz","reef","ridge","sable","sage","shore","thunder","vale","zephyr", ] +def _open_with_retry(req, timeout, attempts=6): + last = None + for attempt in range(1, attempts + 1): + try: + return request.urlopen(req, timeout=timeout) + except error.HTTPError as e: + return e + except (error.URLError, TimeoutError, OSError) as e: + last = e + time.sleep(attempt * 2) + if last: + raise last + raise RuntimeError("request_failed") + def _json(method, url, *, headers=None, body=None, timeout=20): hdrs = {"Content-Type": "application/json"} if headers: @@ -37,18 +52,17 @@ def _json(method, url, *, headers=None, body=None, timeout=20): if body is not None: data = json.dumps(body).encode() req = request.Request(url, data=data, headers=hdrs, method=method) - try: - with request.urlopen(req, timeout=timeout) as resp: - raw = resp.read() - payload = json.loads(raw.decode()) if raw else {} - return resp.status, payload - except error.HTTPError as e: - raw = e.read() + resp = _open_with_retry(req, timeout) + if isinstance(resp, error.HTTPError): + raw = resp.read() try: payload = json.loads(raw.decode()) if raw else {} except Exception: payload = {} - return e.code, payload + return resp.code, payload + raw = resp.read() + payload = json.loads(raw.decode()) if raw else {} + return resp.status, payload def _form(method, url, *, headers=None, fields=None, timeout=20): hdrs = {"Content-Type": "application/x-www-form-urlencoded"} @@ -56,18 +70,17 @@ def _form(method, url, *, headers=None, fields=None, timeout=20): hdrs.update(headers) data = parse.urlencode(fields or {}).encode() req = request.Request(url, data=data, headers=hdrs, method=method) - try: - with request.urlopen(req, timeout=timeout) as resp: - raw = resp.read() - payload = json.loads(raw.decode()) if raw else {} - return resp.status, payload - except error.HTTPError as e: - raw = e.read() + resp = _open_with_retry(req, timeout) + if isinstance(resp, error.HTTPError): + raw = resp.read() try: payload = json.loads(raw.decode()) if raw else {} except Exception: payload = {} - return e.code, payload + return resp.code, payload + raw = resp.read() + payload = json.loads(raw.decode()) if raw else {} + return resp.status, payload _admin_token = None _admin_token_at = 0.0 @@ -110,12 +123,28 @@ def _admin_api(admin_token, method, path, body=None): timeout=20, ) -def _create_user(admin_token, username): - status, payload = _admin_api(admin_token, "POST", "/users", {"username": username}) - if status != 201: - return status, None - user = payload.get("data") or {} - return status, user.get("id") +def _create_user(admin_token, username, password): + payloads = [ + { + "data": { + "type": "user", + "attributes": { + "username": username, + "password": password, + }, + } + }, + {"username": username, "password": password}, + {"username": username}, + ] + for payload in payloads: + status, body = _admin_api(admin_token, "POST", "/users", payload) + if status in (200, 201): + user = body.get("data") or {} + return status, user.get("id") or user.get("id") + if status == 409: + return status, None + return status, None def _set_password(admin_token, user_id, password): status, _payload = _admin_api( @@ -127,20 +156,28 @@ def _set_password(admin_token, user_id, password): return status in (200, 204) def _login_password(username, password): - payload = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": f"@{username}:{SERVER_NAME}"}, - "password": password, - } - status, data = _json( - "POST", - f"{MAS_BASE}/_matrix/client/v3/login", - body=payload, - timeout=20, - ) - if status != 200: - return None, None - return data.get("access_token"), data.get("device_id") + payloads = [ + { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": f"@{username}:{SERVER_NAME}"}, + "password": password, + }, + { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": username}, + "password": password, + }, + ] + for payload in payloads: + status, data = _json( + "POST", + f"{MAS_BASE}/_matrix/client/v3/login", + body=payload, + timeout=20, + ) + if status == 200: + return data.get("access_token"), data.get("device_id") + return None, None def _set_display_name(access_token, user_id, displayname): _json( @@ -224,18 +261,18 @@ class Handler(BaseHTTPRequestHandler): admin_token = _mas_admin_access_token(now) displayname = _generate_displayname() + password = secrets.token_urlsafe(18) localpart = None mas_user_id = None for _ in range(5): localpart = _generate_localpart() - status, mas_user_id = _create_user(admin_token, localpart) - if status == 201 and mas_user_id: + status, mas_user_id = _create_user(admin_token, localpart, password) + if status in (200, 201) and mas_user_id: break mas_user_id = None if not mas_user_id or not localpart: raise RuntimeError("add_user_failed") - password = secrets.token_urlsafe(18) if not _set_password(admin_token, mas_user_id, password): raise RuntimeError("set_password_failed") access_token, device_id = _login_password(localpart, password)