[Segmentation Mask] --> [Encoder (U-Net style or ViT Patch Embedding)] ---> Q

[Text Condition: "CT", "T1-MR"] --> [Condition Embedding] --> K, V

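Q, K, V ---> [Cross-Attention Block] ---> Fused Feature Map --> UNet Backbone --> DDPM/Diffusion Head

Note: the ModalityFieldAdapter code below assigns the attention roles the other way round (Q = modality embedding, K/V = anatomy tokens).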


In [None]:
# Let's write the PyTorch code for the Modality Field Adapter (MFA) module.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModalityFieldAdapter(nn.Module):
    """Inject a modality condition into an anatomy feature map via cross-attention.

    The input feature map is projected to ``embed_dim`` channels with a 1x1
    convolution and flattened into one token per spatial location. The
    modality condition vector is linearly embedded into a single query token
    that attends over all anatomy tokens, producing one context vector per
    sample. That vector is added to every spatial location, the sum is
    layer-normalised over the embedding dimension, and a final 1x1
    convolution projects back to ``in_channels``.

    Args:
        in_channels: Number of channels of the input (and output) feature map.
        cond_dim: Length of the modality condition vector.
        embed_dim: Width of the token / attention space. Passed to
            ``nn.MultiheadAttention`` together with ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, in_channels, cond_dim, embed_dim=128, num_heads=4):
        super().__init__()
        self.anatomy_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.modality_fc = nn.Linear(cond_dim, embed_dim)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.out_proj = nn.Conv2d(embed_dim, in_channels, kernel_size=1)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, modality_cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``modality_cond`` into ``x``.

        Args:
            x: Segmentation-based feature map of shape (B, C, H, W).
            modality_cond: Modality condition of shape (B, cond_dim), e.g. a
                one-hot ``[1, 0, 0]`` for CT. Integer or boolean encodings
                are accepted; they are cast to the embedding layer's dtype.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        # Project anatomy features to token space.
        anatomy_feat = self.anatomy_proj(x)  # (B, embed_dim, H, W)
        anatomy_tokens = anatomy_feat.flatten(2).transpose(1, 2)  # (B, HW, embed_dim)

        # nn.Linear rejects inputs whose dtype differs from its weights, so a
        # plain integer one-hot vector would fail without this cast.
        modality_cond = modality_cond.to(self.modality_fc.weight.dtype)
        modality_embed = self.modality_fc(modality_cond).unsqueeze(1)  # (B, 1, embed_dim)

        # Cross attention: Q = modality, K/V = anatomy tokens.
        attn_out, _ = self.cross_attn(
            query=modality_embed, key=anatomy_tokens, value=anatomy_tokens
        )  # (B, 1, embed_dim)

        # Reshape the per-sample context to (B, embed_dim, 1, 1) and let
        # broadcasting spread it over H x W instead of materialising H*W copies.
        context = attn_out.transpose(1, 2).unsqueeze(-1)
        fused = anatomy_feat + context  # (B, embed_dim, H, W)

        # LayerNorm acts on the last dimension, so move channels last for the
        # normalisation and back to channels-first for the output convolution.
        fused = self.norm(fused.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        return self.out_proj(fused)  # (B, C, H, W)

# Demo instance: single-channel feature maps with a length-3 modality condition.
mfa = ModalityFieldAdapter(
    in_channels=1,
    cond_dim=3,
)
# Bare expression: when run as a notebook cell this displays the module's repr.
mfa
