Pulmo — Two-Stage Explainable Lung-Nodule Analysis

Pulmo is a lightweight, explainable pipeline for chest-CT lung-nodule analysis. It chains two models so you can go from a raw CT volume all the way to per-nodule diagnoses with clinical explanations:

  1. Stage 1 — Detector (HeatmapUNet3D): finds nodule centres in a full CT volume (3D sliding window → centre-probability heatmap → peaks).
  2. Stage 2 — Characteriser (Student2p5D): for each candidate, takes the 7 central axial slices of a single 64³ patch and jointly predicts:
    • Detection — nodule vs. non-nodule (also filters Stage-1 false positives)
    • Malignancy — benign vs. malignant, via a concept bottleneck
    • 8 radiological concepts — subtlety, internal structure, calcification, sphericity, margin, lobulation, spiculation, texture
    • Segmentation — nodule mask of the central slice

Because malignancy is computed as Linear(8 concepts → 2), every malignancy prediction is fully attributable to the 8 clinical concepts — you can read off exactly which concept (e.g. spiculation) drove the decision.
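To make the attribution concrete, here is a minimal sketch of how a Linear(8 → 2) head can be decomposed into per-concept contributions. The weights, bias and concept values are made up for illustration, and `concept_contributions` is a hypothetical helper; the shipped `explain_malignancy` in modeling.py may compute or present this differently.

```python
import math

CONCEPTS = ["subtlety", "internal_structure", "calcification", "sphericity",
            "margin", "lobulation", "spiculation", "texture"]

def concept_contributions(concepts, weight, bias):
    """Split the malignant-minus-benign logit into one term per concept.

    weight is a 2x8 matrix (row 0 = benign, row 1 = malignant); bias has 2 entries.
    """
    contribs = {
        name: (weight[1][i] - weight[0][i]) * concepts[i]
        for i, name in enumerate(CONCEPTS)
    }
    margin = sum(contribs.values()) + (bias[1] - bias[0])
    # For two classes, softmax(logits)[1] equals the sigmoid of the logit difference.
    prob_malignant = 1.0 / (1.0 + math.exp(-margin))
    return contribs, prob_malignant

# Illustrative numbers only; these are NOT the trained Pulmo weights.
weight = [[0.0] * 8, [0.1, 0.0, -0.5, 0.0, 0.2, 0.3, 1.2, 0.1]]
bias = [0.0, -1.0]
concepts = [0.8, 0.1, 0.0, 0.5, 0.6, 0.4, 0.9, 0.7]

contribs, p = concept_contributions(concepts, weight, bias)
top = sorted(contribs, key=lambda k: abs(contribs[k]), reverse=True)[:3]
print(f"P(malignant) = {p:.3f}; strongest concepts: {top}")
```

Because the head is linear, the per-concept terms plus the bias difference sum exactly to the logit margin, which is what makes the explanation complete rather than approximate.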

⚠️ Research use only. Pulmo is not a medical device and must not be used for clinical diagnosis.

How it was built

Stage 2 is the deployment student of a knowledge-distillation pipeline:

  1. A 3D teacher (UNet3D trunk, CNN-only, with concept-bottleneck multi-task heads) was trained on LUNA16/LIDC with focal loss, MixUp and aggressive augmentation. Teacher test: detection AUC 0.998 / malignancy AUC 0.986 / Dice 0.857.
  2. Stage 2 — a 2.5D student — was trained by online distillation (loss = 0.5·hard + 0.5·soft, temperature 3.0) to imitate the frozen teacher, for ~5–10× faster inference at a fraction of the size, with the multi-task metrics preserved (see table).
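The distillation objective above can be sketched for a single classification head as follows. Only the 0.5/0.5 weighting and the temperature of 3.0 come from this card. The KL-divergence form of the soft term and the T² scaling are the common Hinton-style convention and are assumptions here, as is the function name `distillation_loss`.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label,
                      alpha=0.5, temperature=3.0):
    # Hard term: cross-entropy of the student against the ground-truth label.
    hard = -math.log(softmax(student_logits)[label])
    # Soft term: KL(teacher || student) on temperature-softened distributions.
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    soft = sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)
    # Assumed convention: scale by T^2 so soft-term gradients keep their magnitude.
    soft *= temperature ** 2
    return alpha * hard + (1.0 - alpha) * soft

# Student that disagrees with a confident teacher on a "malignant" (label 1) sample.
print(distillation_loss([0.2, -0.1], [-2.0, 2.5], label=1))
```

In the real pipeline the same idea would apply per task (detection, malignancy, concepts, segmentation), with the teacher frozen and only the student updated.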

Stage 1 is a separate 3D centre-detector (HeatmapUNet3D, base=16) trained with a CenterNet-style penalty-reduced focal loss to output a nodule-centre heatmap, then peak-picked + clustered into 3D candidate coordinates. It is kept at full 3D (already lightweight at ~23 MB); distilling it gave no useful size/speed win, so the 3D detector is shipped as-is.
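For readers unfamiliar with the penalty-reduced focal loss, the sketch below shows the per-voxel form popularised by CenterNet. It assumes the target heatmap is 1.0 at nodule centres and decays smoothly (for example as a Gaussian) around them; the exponents alpha=2 and beta=4 are the commonly used defaults, not values confirmed for Pulmo.

```python
import math

def penalty_reduced_focal_loss(pred, target, alpha=2.0, beta=4.0, eps=1e-6):
    """pred, target: flat sequences of heatmap values in [0, 1]."""
    loss, num_pos = 0.0, 0
    for p, y in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        if y == 1.0:
            # Positive voxel (exact centre): focal-weighted log-likelihood.
            loss -= (1.0 - p) ** alpha * math.log(p)
            num_pos += 1
        else:
            # Negative voxel: the penalty shrinks as the target approaches 1,
            # so near-centre voxels are punished less for firing.
            loss -= (1.0 - y) ** beta * p ** alpha * math.log(1.0 - p)
    return loss / max(num_pos, 1)  # normalise by the number of centres

# A confident response right next to a centre costs far less than one far away.
near = penalty_reduced_focal_loss([0.9], [0.9])
far = penalty_reduced_focal_loss([0.9], [0.0])
print(near, far)
```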

Full training notebooks (data prep → labels → patch precompute → concepts → teacher → distillation → evaluation → explainability → Stage-1 detector): [link to your notebooks repo here]

Results (held-out internal split)

Stage 2 — characterisation (patch-level test split):

| Task | Metric | Pulmo (2.5D student) | Teacher (3D) |
|---|---|---|---|
| Detection | AUC | 0.997 | 0.998 |
| Malignancy | AUC | 0.986 | 0.986 |
| Segmentation | Dice | 0.859 | 0.857 |

Stage 1 — detection (scan-level val split, FROC):

| Metric | Value |
|---|---|
| CPM (mean sensitivity @ 1/8…8 FP/scan) | 0.629 |
| Sensitivity @ 16 FP/scan | 0.956 |
| Mean centre distance | 1.85 mm |

Patient-level 80/10/10 split of LUNA16. Stage-2 metrics are patch-level; Stage-1 metrics are scan-level. The pipeline has not been externally validated.
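As a reminder of what the CPM figure in the Stage-1 results means: it is the mean sensitivity at the seven operating points 1/8, 1/4, 1/2, 1, 2, 4 and 8 false positives per scan. The sketch below computes it from a FROC curve using linear interpolation between measured points, which is an assumption about the evaluation code; the curve values are made up and are not Pulmo's.

```python
CPM_FP_RATES = [0.125, 0.25, 0.5, 1, 2, 4, 8]

def interp(x, xs, ys):
    """Piecewise-linear interpolation, clamped at both ends; xs must be increasing."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    for i in range(1, len(xs)):
        if x <= xs[i]:
            t = (x - xs[i - 1]) / (xs[i] - xs[i - 1])
            return ys[i - 1] + t * (ys[i] - ys[i - 1])

def cpm(fp_per_scan, sensitivity):
    """Mean sensitivity over the seven standard FP/scan operating points."""
    values = [interp(r, fp_per_scan, sensitivity) for r in CPM_FP_RATES]
    return sum(values) / len(values)

# Hypothetical FROC curve: (false positives per scan, sensitivity).
fps = [0.125, 0.25, 0.5, 1, 2, 4, 8, 16]
sens = [0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95]
score = cpm(fps, sens)
print(f"CPM = {score:.3f}")
```

Note that the 16 FP/scan point lies outside the CPM range, which is why the table reports it separately.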

Usage

Full scan (both stages)

import numpy as np
from analyze_scan import load_pipeline, analyze_scan

stage1, stage2, device = load_pipeline()

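# volume and spacing are your own inputs (not loaded in this snippet):
#   volume:  (Z, Y, X) array of raw HU values
#   spacing: (sz, sy, sx) voxel size in mm, [z, y, x] order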
findings = analyze_scan(volume, spacing, stage1, stage2, device=device)

for f in findings:
    z, y, x = f["location_voxel"]
    print(z, y, x, f["malignancy_prob"], f["prediction"], f["top_reasons"])

Each finding includes the location, detection/malignancy probabilities, the central-slice segmentation mask, all 8 concept values, and the top concepts driving the malignancy decision. See analyze_scan.py.

Single patch (Stage 2 only)

If you already have a candidate location (your own detector, or LUNA16 candidates.csv), you can run Stage 2 alone — see inference_example.py.

import torch
from huggingface_hub import hf_hub_download
from modeling import load_stage2, crop_stage2_input, explain_malignancy

model = load_stage2(hf_hub_download("ariyul/Pulmo", "student_2p5d_best.pth"))
x = crop_stage2_input(patch_3d, (32, 32, 32))     # 64^3 raw-HU patch -> (1, 7, 64, 64)
with torch.no_grad():
    out = model(x)
mal_p = torch.softmax(out["malignancy"][0], 0)[1].item()   # P(malignant) for the first patch
print(f"malignancy probability: {mal_p:.3f}")
print(explain_malignancy(model, out))             # concept-level explanation

Input / preprocessing

  • HU clip [-1000, 1000], then normalize to [0, 1] (identical for both stages).
  • Stage 1: raw (Z, Y, X) HU volume; processed as sliding-window 3D patches of [64, 128, 128] at native resolution.
  • Stage 2: the 7 central axial slices of a 64³ patch centred on a candidate → (B, 7, 64, 64).
  • Spacing is (sz, sy, sx) in mm, [z, y, x] order.
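The preprocessing listed above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the shipped code: `normalize_hu` and `central_axial_slices` are hypothetical helpers, and the choice of slices 29 to 35 as "central" for an even-sized 64-slice patch is a guess. Use `crop_stage2_input` from modeling.py for real inference.

```python
import numpy as np

HU_MIN, HU_MAX = -1000.0, 1000.0

def normalize_hu(volume):
    """Clip raw HU to [-1000, 1000] and rescale linearly to [0, 1]."""
    clipped = np.clip(volume.astype(np.float32), HU_MIN, HU_MAX)
    return (clipped - HU_MIN) / (HU_MAX - HU_MIN)

def central_axial_slices(patch, n_slices=7):
    """Take the n central slices along z from a (Z, Y, X) patch."""
    start = patch.shape[0] // 2 - n_slices // 2
    return patch[start:start + n_slices]

# Synthetic 64^3 patch: air everywhere, one denser slice through the centre.
patch = np.full((64, 64, 64), -1000.0)
patch[32] = 400.0

# Add a batch axis to get the (B, 7, 64, 64) layout Stage 2 expects.
x = central_axial_slices(normalize_hu(patch))[None]
print(x.shape, float(x.min()), float(x.max()))
```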

Files

  • stage1_detector_v2.pth — Stage-1 detector weights (HeatmapUNet3D)
  • student_2p5d_best.pth — Stage-2 characteriser weights (Student2p5D)
  • modeling.py — both model definitions + find_candidates, crop_stage2_input, explain_malignancy
  • analyze_scan.py — end-to-end pipeline (raw volume → findings)
  • inference_example.py — single-patch (Stage-2-only) example
  • config.json — architecture and preprocessing parameters

Training data & citations

Trained on LUNA16 (a curated subset of LIDC-IDRI). If you use Pulmo, please also credit the underlying datasets:

  • Setio et al., Validation, comparison, and combination of algorithms for automatic detection of pulmonary nodules in CT images: the LUNA16 challenge, Medical Image Analysis, 2017.
  • Armato et al., The Lung Image Database Consortium (LIDC) and Image Database Resource Initiative (IDRI), Medical Physics, 2011.

Limitations

  • Single internal split; no external/multi-centre validation.
  • Trained on LUNA16 preprocessing conventions (resampling, HU window); behaviour on other acquisition protocols is untested.
  • Stage-1 operating point trades recall against false positives (peak_thresh); the pipeline relies on Stage 2 to reject Stage-1 false positives.
  • Concept predictions are learned regressions of LIDC radiologist ratings, not ground-truth measurements.
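To illustrate the peak_thresh trade-off mentioned above, here is a minimal pure-Python sketch of thresholded peak picking on a 3D heatmap. It is a hypothetical stand-in for `find_candidates` in modeling.py, whose actual logic (including the clustering step) may differ.

```python
from itertools import product

def find_peaks(heatmap, peak_thresh):
    """Return (z, y, x, score) for voxels that reach peak_thresh and are
    not exceeded by any voxel in their 3x3x3 neighbourhood."""
    Z, Y, X = len(heatmap), len(heatmap[0]), len(heatmap[0][0])
    offsets = [o for o in product((-1, 0, 1), repeat=3) if o != (0, 0, 0)]
    peaks = []
    for z, y, x in product(range(Z), range(Y), range(X)):
        v = heatmap[z][y][x]
        if v < peak_thresh:
            continue
        neighbours = (
            heatmap[z + dz][y + dy][x + dx]
            for dz, dy, dx in offsets
            if 0 <= z + dz < Z and 0 <= y + dy < Y and 0 <= x + dx < X
        )
        if all(v >= n for n in neighbours):
            peaks.append((z, y, x, v))
    return peaks

# Toy 5x5x5 heatmap: one strong peak, its shoulder, and one weak peak.
h = [[[0.0] * 5 for _ in range(5)] for _ in range(5)]
h[2][2][2] = 0.9   # strong response
h[2][2][3] = 0.5   # shoulder of the strong peak, not a peak itself
h[0][0][0] = 0.3   # weak response

strict = find_peaks(h, peak_thresh=0.5)   # fewer candidates, lower recall
loose = find_peaks(h, peak_thresh=0.2)    # more candidates for Stage 2 to filter
print(len(strict), len(loose))
```

Lowering the threshold admits the weak response as an extra candidate, which is the case where Stage 2's detection head is expected to do the rejecting.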

License

Model weights and code: CC BY 4.0. Underlying datasets carry their own licenses.
