marimo-learn / bin /check_notebook_packages.py
Greg Wilson
feat: overhaul for relaunch
aaef24a
#!/usr/bin/env python
"""Check that marimo notebooks in the same lesson directory agree on package versions.
It is acceptable for different notebooks in a directory to specify different packages,
but if two or more notebooks specify the same package, their version constraints must
be identical.
"""
import argparse
import re
import sys
from collections import defaultdict
from pathlib import Path
# Regex to extract the inline script metadata block (PEP 723)
SCRIPT_BLOCK_RE = re.compile(r"^# /// script\s*\n((?:#[^\n]*\n)*?)# ///", re.MULTILINE)
DEPENDENCY_LINE_RE = re.compile(r'^#\s+"([^"]+)",?\s*$')
def parse_script_header(text: str) -> list[str]:
"""Return the list of dependency strings from a PEP 723 script header, or []."""
match = SCRIPT_BLOCK_RE.search(text)
if not match:
return []
block = match.group(1)
deps: list[str] = []
in_deps = False
for raw_line in block.splitlines():
line = raw_line.lstrip("#").strip()
if line.startswith("dependencies"):
in_deps = True
continue
if in_deps:
if line.startswith("]"):
break
# strip surrounding quotes and comma: e.g. ' "polars==1.0",' -> 'polars==1.0'
stripped = line.strip().strip('"\'').rstrip(",").strip('"\'')
if stripped:
deps.append(stripped)
return deps
def package_name(dep: str) -> str:
"""Extract the bare package name from a PEP 508 dependency string.
Examples:
"polars==1.22.0" -> "polars"
"pandas>=2.0,<3" -> "pandas"
"marimo" -> "marimo"
"""
return re.split(r"[><=!;\s\[]", dep, maxsplit=1)[0].lower()
def check_directory(lesson_dir: Path, only: set[str]) -> list[str]:
"""Return a list of error messages for version inconsistencies among *only* in lesson_dir."""
# Map package name -> {version_spec: [notebook_path, ...]}
seen: dict[str, dict[str, list[str]]] = defaultdict(lambda: defaultdict(list))
for nb in sorted(lesson_dir.glob("*.py")):
if nb.name not in only:
continue
try:
text = nb.read_text(encoding="utf-8")
except IOError:
continue
if "marimo.App" not in text:
continue
for dep in parse_script_header(text):
name = package_name(dep)
seen[name][dep].append(nb.name)
errors: list[str] = []
for name, specs in sorted(seen.items()):
if len(specs) > 1:
errors.append(f" Package '{name}' has conflicting specifications:")
for spec, files in sorted(specs.items()):
errors.append(f" {spec!r} in: {', '.join(files)}")
return errors
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("notebooks", nargs="+", metavar="NOTEBOOK",
help="notebook files to check (grouped by directory)")
args = parser.parse_args()
dir_filter: dict[Path, set[str]] = defaultdict(set)
for nb_path in (Path(p) for p in args.notebooks):
dir_filter[nb_path.parent].add(nb_path.name)
total_errors = 0
for lesson_dir, only in sorted(dir_filter.items()):
errors = check_directory(lesson_dir, only=only)
if errors:
print(f"\n{lesson_dir}/")
for msg in errors:
print(msg)
total_errors += len(errors)
if total_errors:
print(f"\nFound package version inconsistencies in {total_errors} package(s).")
sys.exit(1)
else:
print("All package version specifications are consistent.")
sys.exit(0)
if __name__ == "__main__":
main()