
MidasMap: Automated Immunogold Particle Detection for TEM Synapse Images


The Problem

Neuroscientists use immunogold labeling to visualize receptor proteins at synapses in transmission electron microscopy (TEM) images.

  • 6nm gold beads label AMPA receptors (panAMPA)
  • 12nm gold beads label NR1 (NMDA) receptors
  • 18nm gold beads label vGlut2 (vesicular glutamate transporter)

Manual counting is slow and subjective. Each image takes 30-60 minutes to annotate. With hundreds of synapses per experiment, this becomes a bottleneck.

The Challenge

  • Particles are tiny (4-10 pixels radius) on 2048x2115 images
  • Contrast delta is only 11-39 intensity units on a 0-255 scale
  • Large dark vesicles look similar to gold particles to naive detectors
  • Only 453 labeled particles across 10 training images

Previous Approaches (GoldDigger et al.)

| Approach | Result |
|---|---|
| CenterNet (initial attempt) | "Detection quality remained poor" |
| U-Net heatmap | Macro F1 = 0.005-0.017 |
| GoldDigger/cGAN | "No durable breakthrough" |
| Aggressive filtering | "FP dropped but TP dropped harder" |

Core issues: previous systems failed for three reasons:

  1. Incorrect coordinate conversion (microns treated as normalized values)
  2. Broken loss function (heatmap peaks not exactly 1.0)
  3. Overfitting to fixed training patches

MidasMap Architecture

```
Input: Raw TEM Image (any size)
         |
    [Sliding Window → 512x512 patches]
         |
    ResNet-50 Encoder (pretrained on CEM500K: 500K EM images)
         |
    BiFPN Neck (bidirectional feature pyramid, 2 rounds, 128ch)
         |
    Transposed Conv Decoder → stride-2 output
         |
    +------------------+------------------+
    |                                     |
    Heatmap Head                     Offset Head
    (2ch sigmoid)                    (2ch regression)
    6nm channel                      sub-pixel x,y
    12nm channel                     correction
    |                                     |
    +------------------+------------------+
         |
    Peak Extraction (max-pool NMS)
         |
    Cross-class NMS + Mask Filter
         |
    Output: [(x, y, class, confidence), ...]
```
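The peak-extraction step above (max-pool NMS) can be sketched for a single heatmap channel. This is a minimal illustration, not the project's actual code; `extract_peaks` and its parameters are hypothetical names, and the idea is simply that a pixel survives only if it equals the local maximum of its neighbourhood and clears a confidence threshold:

```python
import numpy as np
from scipy.ndimage import maximum_filter

def extract_peaks(heatmap, threshold=0.3, window=3):
    """Max-pool NMS: keep pixels that are the local maximum of a
    window x window neighbourhood and above the confidence threshold."""
    local_max = maximum_filter(heatmap, size=window, mode="constant")
    peaks = (heatmap == local_max) & (heatmap >= threshold)
    ys, xs = np.nonzero(peaks)
    scores = heatmap[ys, xs]
    return list(zip(xs.tolist(), ys.tolist(), scores.tolist()))

# toy example: one Gaussian bump -> one detection at its centre
hm = np.zeros((8, 8), dtype=np.float32)
hm[3, 4] = 1.0
hm[3, 3] = hm[3, 5] = 0.6
print(extract_peaks(hm))  # -> [(4, 3, 1.0)]
```

In the real pipeline this runs per class channel, and the surviving peaks then pass through cross-class NMS and the mask filter.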

Key Design Decisions

CEM500K Backbone: ResNet-50 pretrained on 500,000 electron microscopy images via self-supervised learning. The backbone already understands EM structures (membranes, vesicles, organelles) before seeing any gold particles. This is why the model reaches F1=0.93 in just 5 epochs.

Stride-2 Output: Standard CenterNet uses stride 4. At stride 4, a 6nm bead (4-6px radius) collapses to 1 pixel — too small to detect reliably. At stride 2, the same bead occupies 2-3 pixels, enough for Gaussian peak detection.

CornerNet Focal Loss: With positive:negative pixel ratio of 1:23,000, standard BCE would learn to predict all zeros. The focal loss uses (1-p)^alpha weighting to focus on hard examples and (1-gt)^beta penalty reduction near peaks.
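A minimal NumPy sketch of the CornerNet-style focal loss described above (the actual `loss.py` implementation is not shown here; `cornernet_focal_loss` is an illustrative name, and alpha=2, beta=4 are the values from the CornerNet paper, assumed rather than confirmed for MidasMap):

```python
import numpy as np

def cornernet_focal_loss(pred, gt, alpha=2.0, beta=4.0, eps=1e-6):
    """CornerNet focal loss on a predicted heatmap.

    pred: sigmoid outputs in (0, 1); gt: Gaussian-splatted targets whose
    particle centres are exactly 1.0. (1-p)^alpha up-weights hard
    positives; (1-gt)^beta down-weights negatives near a peak."""
    pred = np.clip(pred, eps, 1.0 - eps)
    pos = gt == 1.0                      # why Bug 2 below was fatal
    pos_loss = -((1.0 - pred) ** alpha) * np.log(pred) * pos
    neg_loss = -((1.0 - gt) ** beta) * (pred ** alpha) * np.log(1.0 - pred) * (~pos)
    num_pos = max(pos.sum(), 1)          # avoid div-by-zero on all-negative patches
    return (pos_loss.sum() + neg_loss.sum()) / num_pos
```

Normalizing by the positive count rather than the pixel count is what keeps the 1:23,000 imbalance from washing out the signal.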

Raw Image Input: No preprocessing. The CEM500K backbone was trained on raw EM images. Any heavy preprocessing (top-hat, CLAHE) creates a domain gap and hurts performance. The model learns to distinguish particles from vesicles through training data, not handcrafted filters.


Training Strategy

3-Phase Training with Discriminative Learning Rates

| Phase | Epochs | What's Trainable | Learning Rate |
|---|---|---|---|
| 1. Warm-up | 40 | BiFPN + heads only | 1e-3 |
| 2. Deep unfreeze | 40 | + layer3 + layer4 | 1e-5 to 5e-4 |
| 3. Full fine-tune | 60 | All layers | 1e-6 to 2e-4 |

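One common way to realize a discriminative LR range like "1e-5 to 5e-4" is geometric spacing from the deepest unfrozen group to the heads. The exact spacing MidasMap uses is not stated, so this is a sketch under that assumption (`discriminative_lrs` is a hypothetical helper; the resulting dict would typically feed optimizer parameter groups):

```python
def discriminative_lrs(groups, lr_min, lr_max):
    """Assign geometrically spaced learning rates: the deepest group
    gets lr_min, the heads get lr_max, intermediates interpolate."""
    n = len(groups)
    if n == 1:
        return {groups[0]: lr_max}
    ratio = (lr_max / lr_min) ** (1.0 / (n - 1))
    return {g: lr_min * ratio ** i for i, g in enumerate(groups)}

# Phase 2: layer3/layer4 newly unfrozen, LRs spanning 1e-5 -> 5e-4
lrs = discriminative_lrs(["layer3", "layer4", "bifpn", "heads"], 1e-5, 5e-4)
```

The intuition: pretrained deep layers already encode useful EM features, so they move slowly, while the task-specific neck and heads adapt quickly.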
Loss Curve (final model):

```
Phase 1          Phase 2          Phase 3
|                |                |
1.4 |\           |                |
    | \          |                |
1.0 |  \         |                |
    |   ----     |                |
0.8 |       \    |                |
    |        \   |                |
0.6 |         \--+---             |
    |            |   \            |
0.4 |            |    \---        |
    |            |        \-------+---
0.2 |            |                |
    +---+---+----+---+---+----+---+---+--> Epoch
    0   10  20   40  50  60   80  100 140
```

Data Augmentation

  • Random 90-degree rotations (EM is rotation-invariant)
  • Horizontal/vertical flips
  • Conservative brightness/contrast (±8%, preserving the subtle particle signal)
  • Gaussian noise (simulates shot noise)
  • Copy-paste augmentation: real bead crops blended onto training patches
  • 70% hard mining: patches centered on particles, 30% random
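The copy-paste augmentation can be sketched as blending a real bead crop onto a patch with a soft circular mask so no hard edge betrays the paste. The exact blending MidasMap uses is not documented here, so treat `paste_bead` and its alpha value as illustrative assumptions:

```python
import numpy as np

def paste_bead(patch, bead, cx, cy, alpha=0.9):
    """Blend a small square bead crop onto a training patch centred at
    (cx, cy). A circular alpha mask avoids hard square edges that the
    detector could latch onto."""
    h, w = bead.shape
    yy, xx = np.mgrid[0:h, 0:w]
    r = min(h, w) / 2.0
    mask = (((xx - w / 2) ** 2 + (yy - h / 2) ** 2) <= r ** 2).astype(np.float32) * alpha
    y0, x0 = cy - h // 2, cx - w // 2
    region = patch[y0:y0 + h, x0:x0 + w]
    patch[y0:y0 + h, x0:x0 + w] = mask * bead + (1 - mask) * region
    return patch
```

In training, the pasted location would also be added to the patch's ground-truth particle list so the heatmap targets stay consistent.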

Overfitting Prevention

  • Unique patches every epoch: RNG reseeded per sample so the model never sees the same patch twice
  • Early stopping: patience=20 epochs, monitoring validation F1
  • Weight decay: 1e-4 on all parameters

Critical Bugs Found and Fixed

Bug 1: Coordinate Conversion

Problem: CSV files labeled "XY in microns" were assumed to be normalized [0,1] coordinates. They were actual micron values.

Effect: All particle annotations were offset by 50-80 pixels from the real locations. The model was learning to detect particles where none existed.

Fix: Multiply by 1790 px/micron (verified against researcher's color overlay TIFs across 7 synapses).
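The fix is a one-line unit conversion; a minimal sketch (the constant 1790 px/micron comes from the calibration described above, while the function name is illustrative):

```python
PX_PER_MICRON = 1790.0  # calibrated against the researcher's overlay TIFs

def microns_to_pixels(x_um, y_um):
    """CSV coordinates are real micron values, not normalized [0, 1]."""
    return x_um * PX_PER_MICRON, y_um * PX_PER_MICRON

print(microns_to_pixels(0.5, 1.0))  # -> (895.0, 1790.0)
```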

Bug 2: Heatmap Peak Values

Problem: Gaussian peaks were centered at float coordinates, producing peak values of 0.78-0.93 instead of exactly 1.0.

Effect: The CornerNet focal loss uses `pos_mask = (gt == 1.0)` to identify positive pixels. With no pixels at exactly 1.0, the model had zero positive training signal and could not learn at all.

Fix: Center Gaussians at the integer grid point (always produces 1.0). Sub-pixel precision is handled by the offset regression head.
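The fix in sketch form: snap each Gaussian to the nearest integer cell so the peak is exactly 1.0, and keep the discarded fraction as the offset head's regression target (`splat_gaussian` is an illustrative name, not necessarily what `heatmap.py` calls it):

```python
import numpy as np

def splat_gaussian(heatmap, x, y, sigma=2.0):
    """Draw a Gaussian peak centred on the *integer* grid cell so the
    peak value is exactly 1.0, as pos_mask = (gt == 1.0) requires.
    The fractional remainder becomes the offset-head target."""
    cx, cy = int(round(x)), int(round(y))
    h, w = heatmap.shape
    yy, xx = np.mgrid[0:h, 0:w]
    g = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2))
    np.maximum(heatmap, g, out=heatmap)   # overlapping particles: keep the max
    offset = (x - cx, y - cy)             # sub-pixel correction target
    return heatmap, offset

hm = np.zeros((16, 16))
hm, off = splat_gaussian(hm, 7.3, 4.8)   # peak value is exactly 1.0
```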

Bug 3: Overfitting on Fixed Patches

Problem: The dataset generated 200 random patches once at initialization. Every epoch replayed the same patches.

Effect: On fast CUDA GPUs, the model memorized all patches in ~17 epochs (loss crashed from 1.6 to 0.002). Validation F1 peaked at 0.66 and degraded.

Fix: Reseed the RNG per `__getitem__` call so every patch is unique.
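A minimal sketch of that fix, assuming a dataset of random crops (class and attribute names are illustrative, not the actual `dataset.py`): seeding a fresh generator from (epoch, index) makes sampling deterministic within an epoch but different across epochs, so no patch is ever replayed.

```python
import numpy as np

class PatchDataset:
    """Reseed the RNG on every __getitem__ call so each epoch samples
    fresh patches instead of replaying a fixed set of 200."""
    def __init__(self, image, patch=512, length=200):
        self.image, self.patch, self.length = image, patch, length
        self.epoch = 0          # bumped externally at the start of each epoch

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # unique seed per (epoch, sample) -> never the same patch twice
        rng = np.random.default_rng((self.epoch, idx))
        h, w = self.image.shape
        y = rng.integers(0, h - self.patch)
        x = rng.integers(0, w - self.patch)
        return self.image[y:y + self.patch, x:x + self.patch]
```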


Results

Leave-One-Image-Out Cross-Validation (10 folds, 5 seeds each)

| Fold | Avg F1 | Best F1 | Notes |
|---|---|---|---|
| S27 | 0.990 | 0.994 | |
| S8 | 0.981 | 0.988 | |
| S25 | 0.972 | 0.977 | |
| S29 | 0.956 | 0.966 | |
| S1 | 0.930 | 0.940 | |
| S4 | 0.919 | 0.972 | |
| S22 | 0.907 | 0.938 | |
| S13 | 0.890 | 0.912 | |
| S7 | 0.799 | 1.000 | Only 3 particles (noisy metric) |
| S15 | 0.633 | 0.667 | Only 1 particle (noisy metric) |

Mean F1 = 0.943 (8 folds with sufficient annotations)

Per-class Performance (S1 fold, best threshold)

| Class | Precision | Recall | F1 |
|---|---|---|---|
| 6nm (AMPA) | 0.895 | 1.000 | 0.944 |
| 12nm (NR1) | 0.833 | 1.000 | 0.909 |

100% recall on both classes: every particle is found, and the only errors are false positives.

Generalization to Unseen Images

Tested on 15 completely unseen images from a different imaging session. Detections land correctly on particles with no manual tuning. The model successfully detects both 6nm and 12nm particles on:

  • Wild-type (Wt2) samples
  • Heterozygous (Het1) samples
  • Different synapse regions (D1, E3, S1, S10, S12, S18)

System Components

```
MidasMap/
  config/config.yaml        # All hyperparameters
  src/
    preprocessing.py        # Data loading (10 synapses, 453 particles)
    model.py                # CenterNet: ResNet-50 + BiFPN + heads (24.4M params)
    loss.py                 # CornerNet focal loss + offset regression
    heatmap.py              # GT generation + peak extraction + NMS
    dataset.py              # Patch sampling, augmentation, copy-paste
    postprocess.py          # Mask filter, cross-class NMS
    ensemble.py             # D4 TTA + sliding window inference
    evaluate.py             # Hungarian matching, F1/precision/recall
    visualize.py            # Overlay visualizations
  train.py                  # LOOCV training (--fold, --seed)
  train_final.py            # Final deployable model (all data)
  predict.py                # Inference on new images
  evaluate_loocv.py         # Full evaluation runner
  app.py                    # Gradio web dashboard
  slurm/                    # HPC job scripts
  tests/                    # 36 unit tests
```

Dashboard

MidasMap includes a web-based dashboard (Gradio) for interactive use:

  1. Upload any TEM image (.tif)
  2. Adjust confidence threshold and NMS parameters
  3. View detections overlaid on the image
  4. Inspect per-class heatmaps
  5. Analyze confidence distributions and spatial patterns
  6. Export results as CSV (particle_id, x_px, y_px, x_um, y_um, class, confidence)
```bash
python app.py --checkpoint checkpoints/final/final_model.pth
# Opens at http://localhost:7860
```

Future Directions

  1. Spatial analytics: distance to synaptic cleft, nearest-neighbor analysis, Ripley's K-function
  2. Size regression head: predict actual bead diameter instead of binary classification
  3. 18nm detection: extend to vGlut2 particles (3-class model)
  4. Active learning: flag low-confidence detections for human review
  5. Cross-protocol generalization: fine-tune on cryo-EM or different staining protocols

Technical Summary

  • Model: CenterNet with CEM500K-pretrained ResNet-50, BiFPN neck, stride-2 output
  • Training: 3-phase with discriminative LRs, 140 epochs, 453 particles / 10 images
  • Evaluation: Leave-one-image-out CV, Hungarian matching, F1 = 0.943
  • Inference: Sliding window (512x512, 128px overlap), ~10s per image on GPU
  • Output: Per-particle (x, y, class, confidence) with optional heatmap visualization
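The Hungarian-matching evaluation can be sketched with `scipy.optimize.linear_sum_assignment`: predictions and ground truth are matched to minimize total centre distance, and a match counts as a true positive only within a distance threshold. The matching radius `evaluate.py` actually uses is not stated, so 10 px here is illustrative:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_f1(pred, gt, max_dist=10.0):
    """F1 from Hungarian matching of predicted vs ground-truth centres.
    A matched pair is a true positive if it lies within max_dist px."""
    if not pred or not gt:
        return 0.0
    cost = np.linalg.norm(
        np.asarray(pred)[:, None, :] - np.asarray(gt)[None, :, :], axis=-1)
    rows, cols = linear_sum_assignment(cost)      # optimal 1-to-1 matching
    tp = int((cost[rows, cols] <= max_dist).sum())
    precision = tp / len(pred)
    recall = tp / len(gt)
    return 2 * precision * recall / (precision + recall) if tp else 0.0

preds = [(100.0, 100.0), (300.0, 300.0), (50.0, 40.0)]   # one false positive
gts = [(102.0, 101.0), (298.0, 305.0)]
print(match_f1(preds, gts))   # 2 TPs, 1 FP, 0 FNs
```

One-to-one matching is what prevents a single prediction from "claiming" several nearby ground-truth particles, which would inflate recall.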