Anirudh Balaraman

# Architecture

## Patch Extraction

Patches are extracted using MONAI's RandWeightedCropd (when heatmaps are available) or RandCropByPosNegLabeld (without heatmaps):

  • With heatmaps: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map; regions with high DWI and low ADC are sampled more frequently
  • Without heatmaps: Crops are sampled from positive (prostate) regions based on the binary mask

Each scan yields N patches (default 24) of size tile_size x tile_size x depth (default 64x64x3).
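The weighted variant can be sketched as follows. This is a minimal illustration of the sampling idea, not MONAI's internals: patch centers are drawn with probability proportional to `heatmap * mask`, then a `tile × tile × depth` patch is cut around each center (the function name and exact clamping are assumptions).

```python
# Sketch of heatmap-weighted patch sampling (not MONAI's implementation).
import torch

def sample_patches(volume, heatmap, mask, n_patches=24, tile=64, depth=3):
    """volume: [C, D, H, W]; heatmap, mask: [D, H, W]; mask must be non-empty."""
    weights = (heatmap * mask).flatten()
    weights = weights / weights.sum()
    # draw patch centers with probability proportional to the weight map
    idx = torch.multinomial(weights, n_patches, replacement=True)
    C, D, H, W = volume.shape
    d = (idx // (H * W)).clamp(depth // 2, D - depth // 2 - 1)
    h = ((idx % (H * W)) // W).clamp(tile // 2, H - tile // 2)
    w = (idx % W).clamp(tile // 2, W - tile // 2)
    patches = torch.stack([
        volume[:,
               dd - depth // 2: dd - depth // 2 + depth,
               hh - tile // 2: hh - tile // 2 + tile,
               ww - tile // 2: ww - tile // 2 + tile]
        for dd, hh, ww in zip(d.tolist(), h.tolist(), w.tolist())
    ])
    return patches  # [n_patches, C, depth, tile, tile]
```

In the actual pipeline this role is played by MONAI's `RandWeightedCropd` (with the weight map passed via its `w_key` argument).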

## Tensor Shape Convention

Throughout the pipeline, tensors follow the shape [B, N, C, D, H, W]:

| Dim | Meaning                       | Typical Value |
|-----|-------------------------------|---------------|
| B   | Batch size                    | 4–8           |
| N   | Number of patches (instances) | 24            |
| C   | Channels (T2 + DWI + ADC)     | 3             |
| D   | Depth (slices per patch)      | 3             |
| H   | Patch height                  | 64            |
| W   | Patch width                   | 64            |
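In practice the convention means N is folded into the batch dimension whenever patches are processed independently, and unfolded again for per-scan aggregation; a small shape demo (values are arbitrary stand-ins):

```python
import torch

B, N, C, D, H, W = 2, 24, 3, 3, 64, 64
batch = torch.randn(B, N, C, D, H, W)

# What the 3D CNN sees: each patch as an independent sample.
patches = batch.view(B * N, C, D, H, W)

# Per-patch 512-d features, regrouped by scan for aggregation.
feats = torch.randn(B * N, 512).view(B, N, 512)
```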

## MILModel3D

The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B[ResNet18-3D Backbone]
    B --> C[Transformer Encoder<br/>4 layers, 8 heads]
    C --> D[Attention Pooling<br/>512 → 2048 → 1]
    D --> E["Weighted Sum [B, 512]"]
    E --> F["FC Head [B, num_classes]"]
```

### Forward Pass

  1. Backbone: Input is reshaped from [B, N, C, D, H, W] to [B*N, C, D, H, W] and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.

  2. Transformer: The per-patch features from the ResNet18 backbone are reshaped to [B, N, 512] and passed through the transformer encoder, allowing patches from the same scan to attend to one another.

  3. Attention: A two-layer attention network (512 → 2048 → 1 with Tanh) computes a scalar weight per patch, normalized via softmax.

  4. Classification: The attention-weighted sum of patch features produces a single [B, 512] vector per scan, which is projected to class logits by a linear layer.
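The four steps above can be sketched end to end. This is a minimal, self-contained sketch, not the repository's implementation: the 3D ResNet18 is stood in for by a tiny conv encoder (`TinyBackbone`, an assumption) so the example runs anywhere, and `num_classes` is arbitrary.

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Stand-in for the 3D ResNet18: [B*N, C, D, H, W] -> [B*N, 512]."""
    def __init__(self, out_dim=512):
        super().__init__()
        self.conv = nn.Conv3d(3, out_dim, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool3d(1)

    def forward(self, x):
        return self.pool(self.conv(x)).flatten(1)

class MILSketch(nn.Module):
    def __init__(self, dim=512, num_classes=4):
        super().__init__()
        self.backbone = TinyBackbone(dim)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=4)
        self.attention = nn.Sequential(
            nn.Linear(dim, 2048), nn.Tanh(), nn.Linear(2048, 1))
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):                                       # [B, N, C, D, H, W]
        B, N = x.shape[:2]
        feats = self.backbone(x.flatten(0, 1)).view(B, N, -1)   # 1. backbone
        feats = self.transformer(feats)                         # 2. transformer
        a = torch.softmax(self.attention(feats), dim=1)         # 3. attention [B, N, 1]
        pooled = (a * feats).sum(dim=1)                         # 4. weighted sum [B, 512]
        return self.head(pooled)                                # [B, num_classes]
```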

### MIL Modes

| Mode              | Aggregation Strategy                                          |
|-------------------|---------------------------------------------------------------|
| mean              | Average logits across patches                                 |
| max               | Max logits across patches                                     |
| att               | Attention-weighted feature pooling                            |
| att_trans         | Transformer encoder + attention pooling (primary mode)        |
| att_trans_pyramid | Pyramid transformer on intermediate ResNet layers + attention |

The default and primary mode is att_trans.
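The simpler modes differ only in how per-patch outputs are combined; a shape-level sketch (tensor names and values are random stand-ins, shapes follow the convention table above):

```python
import torch

B, N, K = 2, 24, 4
patch_logits = torch.randn(B, N, K)    # per-patch class logits
patch_feats = torch.randn(B, N, 512)   # per-patch features
att = torch.softmax(torch.randn(B, N, 1), dim=1)  # attention weights

mean_pred = patch_logits.mean(dim=1)           # mode "mean": [B, K]
max_pred = patch_logits.max(dim=1).values      # mode "max":  [B, K]
att_pooled = (att * patch_feats).sum(dim=1)    # mode "att":  [B, 512], then FC head
```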

## csPCa_Model

Wraps a frozen MILModel_3D backbone and replaces the classification head:

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone<br/>(ResNet18 + Transformer)"]
    B --> C["Pooled Features [B, 512]"]
    C --> D["SimpleNN Head<br/>512 → 256 → 128 → 1"]
    D --> E["Sigmoid → csPCa Probability"]
```

### SimpleNN

```
Linear(512, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1)   → Sigmoid
```
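A direct PyTorch rendering of this head (the class body follows the layer spec above; the constructor arguments are illustrative):

```python
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    """512-d pooled features -> csPCa probability in [0, 1]."""
    def __init__(self, in_dim=512, p_drop=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, x):  # x: [B, 512]
        return self.net(x)  # [B, 1]
```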

During csPCa training, the backbone's net (ResNet18) and transformer are frozen. The attention module and SimpleNN head remain trainable.
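The freezing scheme amounts to disabling gradients on the frozen submodules; a sketch (the module names `net`, `transformer`, `attention`, `head`, and the tiny stand-in layers, are assumptions):

```python
import torch.nn as nn

# Stand-in modules; the real ones are the ResNet18, transformer,
# attention network, and SimpleNN head.
model = nn.ModuleDict({
    "net": nn.Linear(8, 8),
    "transformer": nn.Linear(8, 8),
    "attention": nn.Linear(8, 1),
    "head": nn.Linear(8, 1),
})

# Freeze the CNN and transformer; leave attention + head trainable.
for name in ("net", "transformer"):
    for p in model[name].parameters():
        p.requires_grad = False

trainable = {n.split(".")[0] for n, p in model.named_parameters() if p.requires_grad}
```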

## Attention Loss

During PI-RADS training with heatmaps enabled, the model uses a dual-loss objective:

```
total_loss = class_loss + lambda_att * attention_loss
```
  • Classification loss: Standard CrossEntropy on PI-RADS labels
  • Attention loss: 1 - cosine_similarity(predicted_attention, heatmap_attention)
    • Heatmap-derived attention labels are computed by summing spatial heatmap values per patch, squaring for sharpness, and normalizing
    • PI-RADS 2 samples get uniform attention (no expected lesion)
    • lambda_att warms up linearly from 0 to 2.0 over the first 25 epochs
    • The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification
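The objective and the heatmap-derived attention labels can be sketched as below. Tensor names and shapes are assumptions ([B, N] attention vectors, [B, N, D, H, W] patch heatmaps); the function names are illustrative, not the repository's.

```python
import torch
import torch.nn.functional as F

def heatmap_attention_labels(patch_heatmaps):
    """[B, N, D, H, W] heatmap patches -> [B, N] attention targets."""
    s = patch_heatmaps.sum(dim=(2, 3, 4))  # sum spatial heatmap values per patch
    s = s ** 2                             # square for sharpness
    return s / s.sum(dim=1, keepdim=True).clamp_min(1e-8)  # normalize per scan

def total_loss(class_logits, labels, pred_att, target_att, epoch,
               warmup_epochs=25, lambda_max=2.0):
    class_loss = F.cross_entropy(class_logits, labels)
    # attention loss: 1 - cosine similarity between predicted and target attention
    att_loss = 1 - F.cosine_similarity(pred_att, target_att, dim=1).mean()
    # lambda_att warms up linearly from 0 to lambda_max over warmup_epochs
    lambda_att = lambda_max * min(epoch / warmup_epochs, 1.0)
    return class_loss + lambda_att * att_loss
```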