| """ |
| Dynamic FC Temporal Attention model for ASD/TD classification. |
| |
| Architecture (STAGIN-inspired, simplified): |
| Input : (B, W, N) — per-window ROI connectivity strength (mean |FC| per ROI) |
| Step 1 : Linear projection N → H |
| Step 2 : Learnable positional encoding over W time steps |
| Step 3 : Transformer encoder (multi-head self-attention over windows) |
| Step 4 : Attention-weighted pooling over W → subject embedding (H,) |
| Step 5 : MLP classifier → 2 |
| |
| Why this works: |
| ASD shows altered *dynamic* connectivity — not just different mean FC but |
| different temporal patterns of connectivity fluctuation across brain states. |
| The self-attention learns which window combinations are most discriminative. |
| """ |


from __future__ import annotations

import torch
from torch import nn


class DynamicFCAttention(nn.Module):
    """Temporal self-attention over dynamic FC windows for ASD/TD classification."""

    def __init__(
        self,
        num_rois: int = 200,
        max_windows: int = 30,
        hidden_dim: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        dropout: float = 0.5,
        num_classes: int = 2,
    ):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        # Step 1: project each window's ROI connectivity profile (N) to hidden dim (H)
        self.input_proj = nn.Sequential(
            nn.Linear(num_rois, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
        )

        # Step 2: learnable positional encoding over the window (time) axis
        self.pos_embed = nn.Parameter(torch.randn(1, max_windows, hidden_dim) * 0.02)

        # Step 3: Transformer encoder (multi-head self-attention across windows)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 2,
            dropout=dropout * 0.5,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Step 4: scalar attention score per window for temporal pooling
        self.time_attn = nn.Linear(hidden_dim, 1)

        # Step 5: MLP classification head
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(
        self,
        bold_windows: torch.Tensor,
        adj: torch.Tensor | None = None,
        return_attention: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Classify subjects from per-window connectivity profiles.

        bold_windows: (B, W, N) mean |FC| strength per ROI for each window.
        adj: accepted for interface compatibility; not used by this model.
        Returns (B, num_classes) logits, plus the (B, W) temporal attention
        weights when ``return_attention`` is True.
        """
        B, W, N = bold_windows.shape

        # Step 1: (B, W, N) → (B, W, H)
        x = self.input_proj(bold_windows)

        # Step 2: add positional encoding for the first W windows (W <= max_windows)
        x = x + self.pos_embed[:, :W, :]

        # Step 3: self-attention across windows
        x = self.transformer(x)

        # Step 4: attention-weighted pooling over windows → (B, H) subject embedding
        attn = torch.softmax(self.time_attn(x).squeeze(-1), dim=1)
        embedding = (x * attn.unsqueeze(-1)).sum(dim=1)

        # Step 5: classification head
        logits = self.head(embedding)

        if return_attention:
            return logits, attn
        return logits
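

if __name__ == "__main__":
    # Minimal smoke test of the forward pass described in the module docstring.
    # The shapes below are illustrative assumptions, not values from the original
    # pipeline: 4 subjects, 12 sliding windows, 200 ROIs.
    model = DynamicFCAttention(num_rois=200, max_windows=30)
    dummy = torch.randn(4, 12, 200)  # (B, W, N) per-window connectivity strengths
    logits, attn = model(dummy, return_attention=True)
    print(logits.shape)  # torch.Size([4, 2])
    print(attn.shape)    # torch.Size([4, 12])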