"""HTTP shim that answers Matrix guest-registration requests.

Accounts live in matrix-authentication-service (MAS). For
``POST /_matrix/client/{v3,r0}/register?kind=guest`` this service creates a
user through the MAS admin API, gives it a random password, logs it in via
``MAS_BASE``'s ``/_matrix/client/v3/login``, best-effort sets a random display
name on Synapse, and returns the login's access token in the shape of a
registration response. Any other ``kind`` is refused with ``M_FORBIDDEN``.
"""

import base64
import json
import os
import random
import secrets
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib import error, parse, request

MAS_BASE = os.environ.get("MAS_BASE", "http://matrix-authentication-service:8080").rstrip("/")
MAS_ADMIN_API_BASE = os.environ.get(
    "MAS_ADMIN_API_BASE", "http://matrix-authentication-service:8081/api/admin/v1"
).rstrip("/")
SYNAPSE_BASE = os.environ.get("SYNAPSE_BASE", "http://othrys-synapse-matrix-synapse:8008").rstrip("/")
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
# Required at runtime. It is checked in main() rather than read with
# os.environ[...] here, so importing this module (e.g. from tests) does not
# raise KeyError when the variable is absent.
MAS_ADMIN_CLIENT_ID = os.environ.get("MAS_ADMIN_CLIENT_ID", "")
MAS_ADMIN_CLIENT_SECRET_FILE = os.environ.get(
    "MAS_ADMIN_CLIENT_SECRET_FILE", "/etc/mas/admin-client/client_secret"
)
MAS_ADMIN_SCOPE = os.environ.get("MAS_ADMIN_SCOPE", "urn:mas:admin")
RATE_WINDOW_SEC = int(os.environ.get("RATE_WINDOW_SEC", "60"))
RATE_MAX = int(os.environ.get("RATE_MAX", "30"))

# Largest request body accepted on the register endpoint. The body is not
# used, so anything beyond a small JSON document is refused.
MAX_BODY_BYTES = 64 * 1024
# Per-connection socket timeout. HTTPServer handles one request at a time, so
# without it a client that stalls mid-request would block every other client.
REQUEST_TIMEOUT_SEC = 30

# Admin tokens are cached for at most this long, and are refreshed this many
# seconds before the expiry reported by the token endpoint (if it reports one).
_ADMIN_TOKEN_MAX_AGE_SEC = 300
_ADMIN_TOKEN_EXPIRY_MARGIN_SEC = 30
# Number of random localparts tried before giving up on user creation.
_CREATE_USER_ATTEMPTS = 5

_REGISTER_PATHS = ("/_matrix/client/v3/register", "/_matrix/client/r0/register")
_CORS_HEADERS = (
    ("Access-Control-Allow-Origin", "*"),
    ("Access-Control-Allow-Methods", "GET, POST, OPTIONS"),
    ("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With"),
)

_rate = {}  # ip -> (window_start, count)
_rate_pruned_at = 0.0  # last time expired entries were dropped from _rate

ADJ = [
    "brisk", "calm", "eager", "gentle", "merry", "nifty", "rapid", "sunny", "witty", "zesty",
    "amber", "bold", "bright", "crisp", "daring", "frosty", "glad", "jolly", "lively", "mellow",
    "quiet", "ripe", "serene", "spry", "tidy", "vivid", "warm", "wild", "clever", "kind",
]
NOUN = [
    "otter", "falcon", "comet", "ember", "grove", "harbor", "meadow", "raven", "river", "summit",
    "breeze", "cedar", "cinder", "cove", "delta", "forest", "glade", "lark", "marsh", "peak",
    "pine", "quartz", "reef", "ridge", "sable", "sage", "shore", "thunder", "vale", "zephyr",
]


def _decode_json(raw):
    """Decode *raw* bytes as a JSON object.

    Returns ``{}`` when *raw* is empty, is not valid UTF-8 JSON, or decodes to
    something other than an object, so callers can always use ``.get``.
    """
    if not raw:
        return {}
    try:
        payload = json.loads(raw.decode())
    except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
        return {}
    return payload if isinstance(payload, dict) else {}


def _send(req, timeout):
    """Perform *req* and return ``(status, payload_dict)``.

    HTTP error statuses are returned instead of raised. Connection-level
    failures (``urllib.error.URLError``, timeouts) still propagate.
    """
    try:
        with request.urlopen(req, timeout=timeout) as resp:
            return resp.status, _decode_json(resp.read())
    except error.HTTPError as exc:
        try:
            raw = exc.read()
        finally:
            exc.close()
        return exc.code, _decode_json(raw)


def _json(method, url, *, headers=None, body=None, timeout=20):
    """Send *body* (if not None) as JSON and return ``(status, payload_dict)``."""
    hdrs = {"Content-Type": "application/json"}
    if headers:
        hdrs.update(headers)
    data = None if body is None else json.dumps(body).encode()
    return _send(request.Request(url, data=data, headers=hdrs, method=method), timeout)


def _form(method, url, *, headers=None, fields=None, timeout=20):
    """Send *fields* form-urlencoded and return ``(status, payload_dict)``."""
    hdrs = {"Content-Type": "application/x-www-form-urlencoded"}
    if headers:
        hdrs.update(headers)
    data = parse.urlencode(fields or {}).encode()
    return _send(request.Request(url, data=data, headers=hdrs, method=method), timeout)


_admin_token = None
_admin_token_at = 0.0
_admin_token_ttl = float(_ADMIN_TOKEN_MAX_AGE_SEC)


def _token_ttl(payload):
    """Return how many seconds a token from token-endpoint *payload* may be cached.

    Capped at ``_ADMIN_TOKEN_MAX_AGE_SEC``. When the payload carries a numeric
    ``expires_in``, the token is refreshed ``_ADMIN_TOKEN_EXPIRY_MARGIN_SEC``
    before that expiry, so a request never starts with an almost-dead token.
    """
    ttl = _ADMIN_TOKEN_MAX_AGE_SEC
    expires_in = payload.get("expires_in")
    if isinstance(expires_in, (int, float)):
        ttl = min(ttl, expires_in - _ADMIN_TOKEN_EXPIRY_MARGIN_SEC)
    return max(0.0, float(ttl))


def _mas_admin_access_token(now):
    """Return a MAS admin API bearer token, fetching a new one when the cache is stale.

    Uses the OAuth2 client-credentials grant against ``MAS_BASE/oauth2/token``,
    with the client secret read from ``MAS_ADMIN_CLIENT_SECRET_FILE``.

    Raises:
        RuntimeError: ``"mas_admin_token_failed"`` if no token was issued.
        OSError: if the secret file cannot be read.
    """
    global _admin_token, _admin_token_at, _admin_token_ttl
    if _admin_token and (now - _admin_token_at) < _admin_token_ttl:
        return _admin_token
    with open(MAS_ADMIN_CLIENT_SECRET_FILE, encoding="utf-8") as fh:
        client_secret = fh.read().strip()
    basic = base64.b64encode(f"{MAS_ADMIN_CLIENT_ID}:{client_secret}".encode()).decode()
    status, payload = _form(
        "POST",
        f"{MAS_BASE}/oauth2/token",
        headers={"Authorization": f"Basic {basic}"},
        fields={"grant_type": "client_credentials", "scope": MAS_ADMIN_SCOPE},
        timeout=20,
    )
    if status != 200 or "access_token" not in payload:
        raise RuntimeError("mas_admin_token_failed")
    _admin_token = payload["access_token"]
    _admin_token_at = now
    _admin_token_ttl = _token_ttl(payload)
    return _admin_token


def _generate_localpart():
    """Return an unpredictable guest localpart: ``guest-`` plus 12 hex chars."""
    return "guest-" + secrets.token_hex(6)


def _generate_displayname():
    """Return a human-friendly ``adjective-noun`` display name (not security-relevant)."""
    return f"{random.choice(ADJ)}-{random.choice(NOUN)}"


def _admin_api(admin_token, method, path, body=None):
    """Call the MAS admin API at *path* with *admin_token*; return ``(status, payload)``."""
    return _json(
        method,
        f"{MAS_ADMIN_API_BASE}{path}",
        headers={"Authorization": f"Bearer {admin_token}"},
        body=body,
        timeout=20,
    )


def _create_user(admin_token, username):
    """Create MAS user *username*; return ``(status, mas_user_id or None)``."""
    status, payload = _admin_api(admin_token, "POST", "/users", {"username": username})
    if status != 201:
        return status, None
    user = payload.get("data")
    if not isinstance(user, dict):
        return status, None
    return status, user.get("id")


def _set_password(admin_token, user_id, password):
    """Set *password* on MAS user *user_id*; return True on a 200/204 response."""
    status, _payload = _admin_api(
        admin_token,
        "POST",
        f"/users/{parse.quote(user_id)}/set-password",
        {"password": password},
    )
    return status in (200, 204)


def _login_password(username, password):
    """Log in via MAS's client login endpoint; return ``(access_token, device_id)``.

    Both values are None when the login does not return HTTP 200.
    """
    payload = {
        "type": "m.login.password",
        "identifier": {"type": "m.id.user", "user": f"@{username}:{SERVER_NAME}"},
        "password": password,
    }
    status, data = _json(
        "POST",
        f"{MAS_BASE}/_matrix/client/v3/login",
        body=payload,
        timeout=20,
    )
    if status != 200:
        return None, None
    return data.get("access_token"), data.get("device_id")


def _set_display_name(access_token, user_id, displayname):
    """Set *user_id*'s display name on Synapse and return the HTTP status code."""
    status, _payload = _json(
        "PUT",
        f"{SYNAPSE_BASE}/_matrix/client/v3/profile/{parse.quote(user_id, safe='')}/displayname",
        headers={"Authorization": f"Bearer {access_token}"},
        body={"displayname": displayname},
        timeout=20,
    )
    return status


def _rate_check(ip, now):
    """Count one request from *ip* at time *now*; return False once over the limit.

    Fixed-window limiter: each IP may make ``RATE_MAX`` requests per window of
    ``RATE_WINDOW_SEC`` seconds, the window opening at the first request after
    the previous one expired. Expired entries are purged at most once per
    window so the table does not keep every address ever seen.
    """
    global _rate_pruned_at
    if now - _rate_pruned_at > RATE_WINDOW_SEC:
        expired = [key for key, (started, _n) in _rate.items() if now - started > RATE_WINDOW_SEC]
        for key in expired:
            del _rate[key]
        _rate_pruned_at = now
    win, cnt = _rate.get(ip, (now, 0))
    if now - win > RATE_WINDOW_SEC:
        _rate[ip] = (now, 1)
        return True
    if cnt >= RATE_MAX:
        return False
    _rate[ip] = (win, cnt + 1)
    return True


def _parse_content_length(value):
    """Parse a Content-Length header value.

    Returns 0 for a missing/empty header, the length for a non-negative
    integer, and None for anything else (non-numeric or negative).
    """
    if not value:
        return 0
    try:
        length = int(value)
    except ValueError:
        return None
    return length if length >= 0 else None


class Handler(BaseHTTPRequestHandler):
    """Request handler for health checks and guest registration."""

    server_version = "matrix-guest-register"
    # socketserver.StreamRequestHandler applies this as the connection's socket
    # timeout; a timed-out read is logged and the connection discarded.
    timeout = REQUEST_TIMEOUT_SEC

    def _send_cors_headers(self):
        """Emit the permissive CORS headers shared by every response."""
        for name, value in _CORS_HEADERS:
            self.send_header(name, value)

    def _send_json(self, code, payload):
        """Write a complete JSON response with status *code*."""
        body = json.dumps(payload).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self._send_cors_headers()
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_OPTIONS(self):  # noqa: N802
        """Answer CORS preflight requests."""
        self.send_response(204)
        self._send_cors_headers()
        self.end_headers()

    def do_GET(self):  # noqa: N802
        """Serve health checks and advertise a stage-less registration flow."""
        path = parse.urlparse(self.path).path
        if path in ("/healthz", "/"):
            return self._send_json(200, {"ok": True})
        if path in _REGISTER_PATHS:
            return self._send_json(200, {"flows": [{"stages": []}]})
        return self._send_json(404, {"errcode": "M_NOT_FOUND", "error": "not_found"})

    def _client_ip(self):
        """Return the address used as the rate-limit key for this request."""
        # NOTE(review): the left-most X-Forwarded-For entry is whatever the
        # client sent, so a client can rotate it to evade the rate limit unless
        # the ingress overwrites this header. Confirm the proxy setup; if the
        # proxy only appends, the right-most entry is the trustworthy one.
        xfwd = self.headers.get("x-forwarded-for", "")
        return (xfwd.split(",")[0].strip() if xfwd else "") or self.client_address[0]

    def _provision_guest(self, now):
        """Create, password-protect and log in a fresh guest; return the response body.

        Raises RuntimeError (or whatever the HTTP/file helpers raise) when any
        required step fails. Failing to set the display name is only logged.
        """
        admin_token = _mas_admin_access_token(now)
        displayname = _generate_displayname()
        # Retry with a new random localpart, e.g. in case of a name collision.
        for _ in range(_CREATE_USER_ATTEMPTS):
            localpart = _generate_localpart()
            status, mas_user_id = _create_user(admin_token, localpart)
            if status == 201 and mas_user_id:
                break
        else:
            raise RuntimeError("add_user_failed")

        password = secrets.token_urlsafe(18)
        if not _set_password(admin_token, mas_user_id, password):
            raise RuntimeError("set_password_failed")

        access_token, device_id = _login_password(localpart, password)
        if not access_token:
            raise RuntimeError("login_failed")

        user_id = f"@{localpart}:{SERVER_NAME}"
        try:
            status = _set_display_name(access_token, user_id, displayname)
        except Exception as exc:  # best effort: a display name is not worth failing registration
            self.log_error("setting display name for %s failed: %r", user_id, exc)
        else:
            if status != 200:
                self.log_error("setting display name for %s returned HTTP %s", user_id, status)

        return {
            "user_id": user_id,
            "access_token": access_token,
            "device_id": device_id or "guest_device",
            "home_server": SERVER_NAME,
        }

    def do_POST(self):  # noqa: N802
        """Handle ``/register``: provision guests, refuse every other kind."""
        parsed = parse.urlparse(self.path)
        if parsed.path not in _REGISTER_PATHS:
            return self._send_json(404, {"errcode": "M_NOT_FOUND", "error": "not_found"})

        kind = (parse.parse_qs(parsed.query).get("kind") or ["user"])[0]
        if kind != "guest":
            return self._send_json(
                403,
                {
                    "errcode": "M_FORBIDDEN",
                    "error": "Registration is disabled; use https://bstein.dev/request-access for accounts.",
                },
            )

        now = time.time()
        if not _rate_check(self._client_ip(), now):
            return self._send_json(429, {"errcode": "M_LIMIT_EXCEEDED", "error": "rate_limited"})

        length = _parse_content_length(self.headers.get("content-length"))
        if length is None:
            return self._send_json(400, {"errcode": "M_UNKNOWN", "error": "invalid_content_length"})
        if length > MAX_BODY_BYTES:
            return self._send_json(413, {"errcode": "M_TOO_LARGE", "error": "request_too_large"})
        if length:
            # The body's contents are not used, but it is consumed so that
            # closing the connection does not discard unread client data.
            self.rfile.read(length)

        try:
            resp = self._provision_guest(now)
        except Exception as exc:  # any upstream failure maps to a single 502
            self.log_error("guest provisioning failed: %r", exc)
            return self._send_json(502, {"errcode": "M_UNKNOWN", "error": "guest_provision_failed"})
        return self._send_json(200, resp)


def main():
    """Validate required configuration and serve forever on ``$PORT`` (default 8080)."""
    if "MAS_ADMIN_CLIENT_ID" not in os.environ:
        raise SystemExit("MAS_ADMIN_CLIENT_ID must be set")
    port = int(os.environ.get("PORT", "8080"))
    HTTPServer(("0.0.0.0", port), Handler).serve_forever()


if __name__ == "__main__":
    main()