titan-iac/services/communication/guest-register-configmap.yaml

130 lines
5.3 KiB
YAML

# services/communication/guest-register-configmap.yaml
# ConfigMap that carries the guest-registration helper script as the key "server.py".
apiVersion: v1
kind: ConfigMap
metadata:
name: matrix-guest-register
data:
# server.py: stdlib-only Python HTTP service that forwards guest registrations to Synapse.
server.py: |
import json
import os
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib import error, parse, request
# --- Runtime configuration, read once from the environment at import time ---

# Base URL of the upstream Synapse server. The trailing slash is stripped so it
# can be concatenated directly with GUEST_REGISTER_PATH.
SYNAPSE_BASE = os.environ.get("SYNAPSE_BASE", "http://othrys-synapse-matrix-synapse:8008").rstrip("/")
# Required: value sent upstream in the GUEST_REGISTER_HEADER header. There is
# no default, so a missing variable raises KeyError while the module loads.
GUEST_REGISTER_SHARED_SECRET = os.environ["GUEST_REGISTER_SHARED_SECRET"]
# Name of the request header that carries the shared secret.
GUEST_REGISTER_HEADER = os.environ.get("GUEST_REGISTER_HEADER", "x-guest-register-secret")
# Upstream path (appended to SYNAPSE_BASE) that guest registrations are POSTed to.
GUEST_REGISTER_PATH = os.environ.get("GUEST_REGISTER_PATH", "/_matrix/_guest_register")
# Rate limiting: at most RATE_MAX accepted requests per client key in each
# RATE_WINDOW_SEC-second window (enforced by _rate_check).
RATE_WINDOW_SEC = int(os.environ.get("RATE_WINDOW_SEC", "60"))
RATE_MAX = int(os.environ.get("RATE_MAX", "30"))
# In-process rate-limit table; it is not shared between processes or replicas.
_rate = {} # ip -> (window_start, count) tuple, written by _rate_check
def _json(method, url, *, headers=None, body=None, timeout=20):
hdrs = {"Content-Type": "application/json"}
if headers:
hdrs.update(headers)
data = None
if body is not None:
data = json.dumps(body).encode()
req = request.Request(url, data=data, headers=hdrs, method=method)
try:
with request.urlopen(req, timeout=timeout) as resp:
raw = resp.read()
payload = json.loads(raw.decode()) if raw else {}
return resp.status, payload
except error.HTTPError as e:
raw = e.read()
try:
payload = json.loads(raw.decode()) if raw else {}
except Exception:
payload = {}
return e.code, payload
def _rate_check(ip, now, *, sweep_every=4096):
    """Fixed-window rate limiter: record one request for ``ip`` and say if it is allowed.

    Args:
        ip: Client key (the caller passes an IP address string).
        now: Current time in seconds; only differences between values matter.
        sweep_every: Whenever the table size is a non-zero multiple of this
            value, entries whose window has expired are evicted. This keeps
            the table from growing without bound when many distinct keys are
            seen, while amortising the O(n) sweep over many calls. A value
            ``<= 0`` disables sweeping.

    Returns:
        ``True`` if the request fits in the current window (and was counted),
        ``False`` if ``RATE_MAX`` requests were already counted in this window.
    """
    if sweep_every > 0 and _rate and len(_rate) % sweep_every == 0:
        # The caller's own entry is left alone so the decision logic below is
        # exactly the same whether or not a sweep happened.
        expired = [
            key
            for key, (started, _count) in _rate.items()
            if key != ip and now - started > RATE_WINDOW_SEC
        ]
        for key in expired:
            del _rate[key]

    win, cnt = _rate.get(ip, (now, 0))
    if now - win > RATE_WINDOW_SEC:
        # Previous window is over: start a new one with this request counted.
        _rate[ip] = (now, 1)
        return True
    if cnt >= RATE_MAX:
        return False
    _rate[ip] = (win, cnt + 1)
    return True
class Handler(BaseHTTPRequestHandler):
    """HTTP front end that relays guest registrations to Synapse and refuses the rest.

    Routes:
        * ``GET /`` and ``GET /healthz``: liveness probe, ``{"ok": true}``.
        * ``GET <register path>``: advertises a single registration flow with
          no stages.
        * ``POST <register path>?kind=guest``: rate limited, then forwarded to
          ``SYNAPSE_BASE + GUEST_REGISTER_PATH`` with the shared-secret header.
        * ``POST <register path>`` with any other ``kind``: 403.
        * ``OPTIONS`` on any path: CORS preflight, 204.

    Every JSON response carries permissive CORS headers.
    """

    server_version = "matrix-guest-register"

    # Client-facing registration endpoints (v3 and legacy r0 prefixes).
    register_paths = ("/_matrix/client/v3/register", "/_matrix/client/r0/register")

    # Largest request body that will be read into memory. The body is only
    # relayed upstream, so anything bigger is rejected with 413.
    max_body_bytes = 64 * 1024

    def _send_cors_headers(self):
        """Emit the CORS headers shared by preflight and JSON responses."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")

    def _send_json(self, code, payload):
        """Write a complete JSON response with status ``code`` and body ``payload``."""
        body = json.dumps(payload).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self._send_cors_headers()
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_OPTIONS(self):  # noqa: N802
        """Answer a CORS preflight with 204 and no body."""
        self.send_response(204)
        self._send_cors_headers()
        self.end_headers()

    def do_GET(self):  # noqa: N802
        """Serve the health probe and the registration-flow discovery document."""
        parsed = parse.urlparse(self.path)
        if parsed.path in ("/healthz", "/"):
            return self._send_json(200, {"ok": True})
        if parsed.path in self.register_paths:
            return self._send_json(200, {"flows": [{"stages": []}]})
        return self._send_json(404, {"errcode": "M_NOT_FOUND", "error": "not_found"})

    def do_POST(self):  # noqa: N802
        """Validate, rate limit and forward a guest registration request."""
        parsed = parse.urlparse(self.path)
        if parsed.path not in self.register_paths:
            return self._send_json(404, {"errcode": "M_NOT_FOUND", "error": "not_found"})

        # A missing ``kind`` is treated as "user", i.e. a normal signup.
        kind = (parse.parse_qs(parsed.query).get("kind") or ["user"])[0]
        if kind != "guest":
            return self._send_json(
                403,
                {
                    "errcode": "M_FORBIDDEN",
                    "error": "Registration is disabled; use https://bstein.dev/request-access for accounts.",
                },
            )

        # Validate Content-Length up front: a non-numeric value used to raise,
        # and a negative one made rfile.read() block until the peer hung up.
        try:
            length = int(self.headers.get("content-length") or "0")
        except ValueError:
            length = -1
        if length < 0:
            return self._send_json(400, {"errcode": "M_UNKNOWN", "error": "invalid_content_length"})
        if length > self.max_body_bytes:
            return self._send_json(413, {"errcode": "M_TOO_LARGE", "error": "body_too_large"})

        # NOTE(review): the first X-Forwarded-For hop is client-supplied unless
        # a trusted proxy overwrites it; confirm the ingress does so, otherwise
        # callers can pick their own rate-limit key.
        xfwd = self.headers.get("x-forwarded-for", "")
        ip = (xfwd.split(",")[0].strip() if xfwd else "") or self.client_address[0]
        # Monotonic clock: only elapsed time matters, and wall-clock jumps must
        # not reset or extend rate-limit windows.
        if not _rate_check(ip, time.monotonic()):
            return self._send_json(429, {"errcode": "M_LIMIT_EXCEEDED", "error": "rate_limited"})

        raw = self.rfile.read(length) if length else b""
        # Best effort by design: anything that is not a JSON object is relayed
        # as an empty object rather than rejected.
        try:
            body = json.loads(raw.decode()) if raw else {}
        except (ValueError, RecursionError):
            body = {}
        if not isinstance(body, dict):
            body = {}

        try:
            status, payload = _json(
                "POST",
                f"{SYNAPSE_BASE}{GUEST_REGISTER_PATH}",
                headers={GUEST_REGISTER_HEADER: GUEST_REGISTER_SHARED_SECRET},
                body=body,
                timeout=20,
            )
        except (OSError, ValueError):
            # Defensive: a transport or decoding error escaping _json becomes a
            # gateway error instead of an unhandled exception in the handler.
            return self._send_json(502, {"errcode": "M_UNKNOWN", "error": "upstream_unavailable"})

        # Never relay a refresh token to the client. Upstream may answer with a
        # non-object JSON value, which has no such key to strip.
        if isinstance(payload, dict):
            payload.pop("refresh_token", None)
        return self._send_json(status, payload)
def main():
    """Run the guest-register HTTP server until the process is stopped.

    Listens on every interface; the port is taken from the ``PORT``
    environment variable and defaults to 8080.
    """
    listen_port = int(os.getenv("PORT", "8080"))
    httpd = HTTPServer(("0.0.0.0", listen_port), Handler)
    httpd.serve_forever()


if __name__ == "__main__":
    main()