#!/usr/bin/env python3
"""Require docstrings on public production APIs."""

from __future__ import annotations

import argparse
import ast
from collections.abc import Sequence
from pathlib import Path

# Simple base-class names whose direct subclasses are exempt from the class
# docstring requirement.
_SKIPPED_BASES = frozenset({"Exception", "RuntimeError", "BaseException", "BaseModel"})

# Directory names that are never scanned.
_SKIPPED_DIRS = frozenset({"__pycache__", ".venv"})


def _is_dataclass_decorator(decorator: ast.expr) -> bool:
    """Return whether a decorator expression refers to ``dataclass``."""
    # Unwrap the call form, e.g. ``@dataclass(frozen=True)``.
    target = decorator.func if isinstance(decorator, ast.Call) else decorator
    if isinstance(target, ast.Name):
        return target.id == "dataclass"
    # Qualified form, e.g. ``@dataclasses.dataclass``.
    if isinstance(target, ast.Attribute):
        return target.attr == "dataclass"
    return False


def _is_dataclass_class(node: ast.ClassDef) -> bool:
    """Return whether a class uses the dataclass decorator.

    Recognises the bare name (``@dataclass``), the qualified name
    (``@dataclasses.dataclass``) and the call form of either.
    """
    return any(_is_dataclass_decorator(dec) for dec in node.decorator_list)


def _base_names(node: ast.ClassDef) -> set[str]:
    """Return simple base class names used by a class definition.

    Only bases written as a plain name are collected; dotted or subscripted
    bases are ignored.
    """
    return {base.id for base in node.bases if isinstance(base, ast.Name)}


def _needs_function_docstring(
    node: ast.FunctionDef | ast.AsyncFunctionDef,
    parent_class: str | None,
) -> bool:
    """Return whether a public function-like node needs a docstring.

    Names starting with an underscore are exempt. The one exception is
    ``__init__``, which needs a docstring only when it is not a method
    (``parent_class`` is falsy).
    """
    if node.name == "__init__":
        return not parent_class
    return not node.name.startswith("_")


def _needs_class_docstring(node: ast.ClassDef) -> bool:
    """Return whether a public class-like node needs a docstring.

    Private classes, dataclasses, and classes that directly list one of
    ``_SKIPPED_BASES`` as a simple base name are exempt.
    """
    if node.name.startswith("_") or _is_dataclass_class(node):
        return False
    return _SKIPPED_BASES.isdisjoint(_base_names(node))


def _needs_docstring(node: ast.AST, *, parent_class: str | None = None) -> bool:
    """Return whether `node` should carry an API contract docstring."""
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        return _needs_function_docstring(node, parent_class)
    if isinstance(node, ast.ClassDef):
        return _needs_class_docstring(node)
    return False


def _iter_nodes(tree: ast.AST) -> list[tuple[ast.AST, str | None]]:
    """Return top-level surface area nodes for contract checking.

    Each entry is ``(node, parent_class)``. Only the direct children of the
    module body are returned, so ``parent_class`` is always ``None``.
    """
    return [(node, None) for node in getattr(tree, "body", [])]


def _file_violations(path: Path) -> list[str]:
    """Return one message per top-level definition in `path` lacking a docstring."""
    source = path.read_text(encoding="utf-8")
    # Pass the filename so a SyntaxError names the offending file.
    tree = ast.parse(source, filename=str(path))
    messages: list[str] = []
    for node, parent_class in _iter_nodes(tree):
        if not _needs_docstring(node, parent_class=parent_class):
            continue
        if ast.get_docstring(node):
            continue
        if isinstance(node, ast.ClassDef):
            messages.append(f"{path}: class {node.name} is missing a docstring")
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            owner = f"{parent_class}." if parent_class else ""
            messages.append(f"{path}: {owner}{node.name} is missing a docstring")
    return messages


def main(argv: Sequence[str] | None = None) -> int:
    """Scan the production package and fail on missing docstrings.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv``.

    Returns:
        ``1`` if any violation was found (each one is printed on its own
        line), otherwise ``0``.

    Raises:
        SyntaxError: If a scanned file cannot be parsed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", default="ariadne")
    args = parser.parse_args(argv)

    root = Path(args.root)
    violations: list[str] = []
    for path in sorted(root.rglob("*.py")):
        if not _SKIPPED_DIRS.isdisjoint(path.parts):
            continue
        violations.extend(_file_violations(path))

    for item in violations:
        print(item)
    return 1 if violations else 0


if __name__ == "__main__":
    raise SystemExit(main())