comms: fix synapse seed booleans

This commit is contained in:
Brad Stein 2026-01-08 05:00:58 -03:00
parent 28bcf716d0
commit c05cb414aa

View File

@ -2,7 +2,7 @@
apiVersion: batch/v1 apiVersion: batch/v1
kind: Job kind: Job
metadata: metadata:
name: synapse-user-seed-1 name: synapse-user-seed-2
namespace: comms namespace: comms
spec: spec:
backoffLimit: 1 backoffLimit: 1
@ -56,14 +56,18 @@ spec:
def get_cols(cur): def get_cols(cur):
cur.execute( cur.execute(
""" """
SELECT column_name, is_nullable, column_default SELECT column_name, is_nullable, column_default, data_type
FROM information_schema.columns FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = 'users' WHERE table_schema = 'public' AND table_name = 'users'
""" """
) )
cols = {} cols = {}
for name, is_nullable, default in cur.fetchall(): for name, is_nullable, default, data_type in cur.fetchall():
cols[name] = {"nullable": is_nullable == "YES", "default": default} cols[name] = {
"nullable": is_nullable == "YES",
"default": default,
"type": data_type,
}
return cols return cols
def upsert_user(cur, cols, user_id, password, admin): def upsert_user(cur, cols, user_id, password, admin):
@ -73,14 +77,18 @@ spec:
"password_hash": bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode(), "password_hash": bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode(),
"creation_ts": now_ms, "creation_ts": now_ms,
} }
if "admin" in cols: def add_flag(name, flag):
values["admin"] = admin if name not in cols:
if "deactivated" in cols: return
values["deactivated"] = False if cols[name]["type"] in ("smallint", "integer"):
if "shadow_banned" in cols: values[name] = int(flag)
values["shadow_banned"] = False else:
if "is_guest" in cols: values[name] = bool(flag)
values["is_guest"] = False
add_flag("admin", admin)
add_flag("deactivated", False)
add_flag("shadow_banned", False)
add_flag("is_guest", False)
columns = list(values.keys()) columns = list(values.keys())
placeholders = ", ".join(["%s"] * len(columns)) placeholders = ", ".join(["%s"] * len(columns))