ariadne/scripts/check_docstrings.py

92 lines
3.1 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""Require docstrings on public production APIs."""
from __future__ import annotations
import argparse
import ast
from pathlib import Path
def _is_dataclass_class(node: ast.ClassDef) -> bool:
    """Return whether a class is decorated with ``dataclass``.

    Both the bare and the called decorator forms are recognised, whether the
    decorator was imported directly (``@dataclass`` / ``@dataclass(frozen=True)``)
    or is reached through its module (``@dataclasses.dataclass`` /
    ``@dataclasses.dataclass(slots=True)``).
    """
    for dec in node.decorator_list:
        # For the called form the decorator expression lives in ``Call.func``.
        target = dec.func if isinstance(dec, ast.Call) else dec
        if isinstance(target, ast.Name) and target.id == "dataclass":
            return True
        if isinstance(target, ast.Attribute) and target.attr == "dataclass":
            return True
    return False
def _base_names(node: ast.ClassDef) -> set[str]:
    """Return the final name component of each base class in a definition.

    ``class A(Exception)`` yields ``{"Exception"}`` and
    ``class B(pydantic.BaseModel)`` yields ``{"BaseModel"}``. Callers can
    therefore match well-known bases regardless of how they were imported.
    Bases that are neither plain names nor dotted attributes (for example
    subscripted generics such as ``Generic[T]``) are ignored.
    """
    names: set[str] = set()
    for base in node.bases:
        if isinstance(base, ast.Name):
            names.add(base.id)
        elif isinstance(base, ast.Attribute):
            # Dotted base: keep only the last component (``pkg.mod.Cls`` -> ``Cls``).
            names.add(base.attr)
    return names
def _needs_function_docstring(node: ast.FunctionDef | ast.AsyncFunctionDef, parent_class: str | None) -> bool:
    """Decide whether a (possibly async) function definition must be documented.

    Names without a leading underscore always need a docstring. Among
    underscore-prefixed names, only ``__init__`` can qualify, and only when it
    is not inside a class (``parent_class`` is falsy).
    """
    name = node.name
    if not name.startswith("_"):
        return True
    return name == "__init__" and not parent_class
def _needs_class_docstring(node: ast.ClassDef) -> bool:
    """Decide whether a class definition must document its contract.

    Exempt classes are:
    - private ones (leading underscore);
    - ``dataclass``-decorated ones;
    - those whose base names, as reported by ``_base_names``, include an
      exception base or ``BaseModel``.
    """
    if node.name.startswith("_"):
        return False
    if _is_dataclass_class(node):
        return False
    exempt_bases = {"Exception", "RuntimeError", "BaseException", "BaseModel"}
    return _base_names(node).isdisjoint(exempt_bases)
def _needs_docstring(node: ast.AST, *, parent_class: str | None = None) -> bool:
    """Route `node` to the class or function docstring rule.

    ``parent_class`` is forwarded only to the function rule. Nodes that are
    neither classes nor (async) functions never need a docstring.
    """
    if isinstance(node, ast.ClassDef):
        return _needs_class_docstring(node)
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    if not isinstance(node, function_types):
        return False
    return _needs_function_docstring(node, parent_class)
def _iter_nodes(tree: ast.AST) -> list[tuple[ast.AST, str | None]]:
    """Return each top-level statement of `tree` paired with ``None`` as its owner class.

    Only the tree's own ``body`` is visited. Methods nested inside classes are
    not returned, so the owner slot is always ``None``. A node without a
    ``body`` attribute produces an empty list.
    """
    statements = list(getattr(tree, "body", []))
    return list(zip(statements, [None] * len(statements)))
def main() -> int:
    """Scan the production package and fail on missing docstrings.

    ``--root`` may name a directory, which is searched recursively for
    ``*.py`` files while skipping anything under ``__pycache__`` or ``.venv``.
    It may also name a single Python file. Every node selected by
    ``_iter_nodes`` and flagged by ``_needs_docstring`` must carry a non-empty
    docstring. Each violation is printed on its own line.

    Returns:
        ``1`` if any violation was found, otherwise ``0``. When ``--root`` does
        not exist, argparse reports the error and exits with status ``2``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--root",
        default="ariadne",
        help="package directory or single .py file to scan (default: %(default)s)",
    )
    args = parser.parse_args()
    root = Path(args.root)
    if not root.exists():
        # A mistyped path used to match no files and exit 0, silently passing CI.
        parser.error(f"--root does not exist: {root}")
    # Path.rglob only searches inside directories, so handle a file root explicitly.
    paths = [root] if root.is_file() else sorted(root.rglob("*.py"))
    violations: list[str] = []
    for path in paths:
        if "__pycache__" in path.parts or ".venv" in path.parts:
            continue
        # Parse raw bytes so the compiler honours a UTF-8 BOM or PEP 263 coding
        # declaration. ``filename`` makes any SyntaxError name the offending file.
        tree = ast.parse(path.read_bytes(), filename=str(path))
        for node, parent_class in _iter_nodes(tree):
            if not _needs_docstring(node, parent_class=parent_class):
                continue
            if ast.get_docstring(node):
                continue
            if isinstance(node, ast.ClassDef):
                violations.append(f"{path}: class {node.name} is missing a docstring")
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                owner = f"{parent_class}." if parent_class else ""
                violations.append(f"{path}: {owner}{node.name} is missing a docstring")
    for item in violations:
        print(item)
    return 1 if violations else 0
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)