diff --git a/ariadne/db/database.py b/ariadne/db/database.py index ba1b2b5..3582717 100644 --- a/ariadne/db/database.py +++ b/ariadne/db/database.py @@ -53,6 +53,45 @@ class Database: conn.row_factory = psycopg.rows.dict_row yield conn + def _configure_timeouts(self, conn: psycopg.Connection) -> None: + try: + conn.execute(f"SET lock_timeout = '{self._lock_timeout_sec}s'") + conn.execute(f"SET statement_timeout = '{self._statement_timeout_sec}s'") + except Exception: + pass + + def _try_advisory_lock(self, conn: psycopg.Connection, lock_id: int) -> bool: + row = conn.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)).fetchone() + if isinstance(row, dict): + return bool(row.get("pg_try_advisory_lock")) + return bool(row and row[0]) + + def _apply_ariadne_tables(self, conn: psycopg.Connection) -> bool: + for stmt in ARIADNE_TABLES_SQL: + try: + conn.execute(stmt) + except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: + logger.warning("schema ensure skipped due to lock timeout: %s", exc) + return False + return True + + def _apply_access_requests(self, conn: psycopg.Connection) -> bool: + try: + conn.execute(ARIADNE_ACCESS_REQUEST_ALTER_SQL) + except psycopg.errors.UndefinedTable: + logger.info("access_requests table missing; skipping alter") + return True + except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: + logger.warning("schema ensure skipped due to lock timeout: %s", exc) + return False + return True + + def _unlock(self, conn: psycopg.Connection, lock_id: int) -> None: + try: + conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,)) + except Exception: + pass + def migrate( self, lock_id: int, @@ -61,39 +100,16 @@ class Database: include_access_requests: bool = True, ) -> None: with self.connection() as conn: - try: - conn.execute(f"SET lock_timeout = '{self._lock_timeout_sec}s'") - conn.execute(f"SET statement_timeout = '{self._statement_timeout_sec}s'") - except Exception: - pass - row = conn.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)).fetchone() - if isinstance(row, dict): - locked = bool(row.get("pg_try_advisory_lock")) - else: - locked = bool(row and row[0]) - if not locked: + self._configure_timeouts(conn) + if not self._try_advisory_lock(conn, lock_id): return try: - if include_ariadne_tables: - for stmt in ARIADNE_TABLES_SQL: - try: - conn.execute(stmt) - except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: - logger.warning("schema ensure skipped due to lock timeout: %s", exc) - return - if include_access_requests: - try: - conn.execute(ARIADNE_ACCESS_REQUEST_ALTER_SQL) - except psycopg.errors.UndefinedTable: - logger.info("access_requests table missing; skipping alter") - except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: - logger.warning("schema ensure skipped due to lock timeout: %s", exc) - return + if include_ariadne_tables and not self._apply_ariadne_tables(conn): + return + if include_access_requests and not self._apply_access_requests(conn): + return finally: - try: - conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,)) - except Exception: - pass + self._unlock(conn, lock_id) def fetchone(self, query: str, params: Iterable[Any] | None = None) -> dict[str, Any] | None: with self.connection() as conn: