# Design sketch (from the notebook's markdown cell):
#
#   [Segmentation Mask] --> [Encoder (U-Net style or ViT Patch Embedding)] ---> Q
#
#   [Text Condition: "CT", "T1-MR"] --> [Condition Embedding] --> K, V
#
#   Q, K, V ---> [Cross-Attention Block] ---> Fused Feature Map --> UNet Backbone --> DDPM/Diffusion Head
#
# NOTE(review): the implementation below assigns the attention roles the other
# way round (Q = modality embedding, K/V = anatomy tokens) -- confirm which of
# the two is the intended design.

# Let's write the PyTorch code for the Modality Field Adapter (MFA) module.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModalityFieldAdapter(nn.Module):
    """Fuse a modality condition vector into a spatial feature map via cross-attention.

    The input feature map is projected to ``embed_dim`` channels with a 1x1
    convolution and flattened into one token per pixel. The modality vector is
    embedded into a single query token that attends over those pixel tokens.
    The resulting vector is added to every spatial location, layer-normalised
    over the channel axis, and projected back to ``in_channels``.

    Args:
        in_channels: Channel count of the input (and output) feature map.
        cond_dim: Length of the modality condition vector.
        embed_dim: Width of the internal token space. ``nn.MultiheadAttention``
            requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, in_channels, cond_dim, embed_dim=128, num_heads=4):
        super().__init__()
        # Submodules are created in the same order (and under the same names)
        # as before, so state_dict keys and seeded initialisation are unchanged.
        self.anatomy_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.modality_fc = nn.Linear(cond_dim, embed_dim)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.out_proj = nn.Conv2d(embed_dim, in_channels, kernel_size=1)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, modality_cond):
        """Apply the adapter.

        Args:
            x: segmentation-based feature map of shape (B, C, H, W).
            modality_cond: modality condition vector of shape (B, cond_dim),
                e.g., one-hot [1, 0, 0] for CT. It is fed to ``nn.Linear``, so
                it must share the layer's (floating-point) dtype.

        Returns:
            Tensor of shape (B, C, H, W).

        Raises:
            ValueError: if ``x`` is not 4-D, or ``modality_cond`` is not 2-D
                with the same batch size as ``x``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"x must have shape (B, C, H, W), got {tuple(x.shape)}"
            )
        if modality_cond.dim() != 2 or modality_cond.shape[0] != x.shape[0]:
            raise ValueError(
                "modality_cond must have shape (B, cond_dim) with the same batch "
                f"size as x, got {tuple(modality_cond.shape)} for x of shape "
                f"{tuple(x.shape)}"
            )

        # Project anatomy features to token space
        anatomy_feat = self.anatomy_proj(x)  # (B, embed_dim, H, W)
        anatomy_tokens = anatomy_feat.flatten(2).transpose(1, 2)  # (B, HW, embed_dim)

        # Get modality embedding as a single query token
        modality_embed = self.modality_fc(modality_cond).unsqueeze(1)  # (B, 1, embed_dim)

        # Cross attention: Q=modality, K/V=anatomy tokens.
        # The attention weights are never used, so skip computing them.
        attn_out, _ = self.cross_attn(
            query=modality_embed,
            key=anatomy_tokens,
            value=anatomy_tokens,
            need_weights=False,
        )  # (B, 1, embed_dim)

        # The attended vector is the same for every pixel, so let broadcasting
        # spread it over (H, W) instead of materialising H*W copies.
        attn_map = attn_out.transpose(1, 2).unsqueeze(-1)  # (B, embed_dim, 1, 1)

        # Combine with anatomy features
        fused = anatomy_feat + attn_map  # (B, embed_dim, H, W)

        # LayerNorm normalises the last axis, so move channels last and back.
        fused = self.norm(fused.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        return self.out_proj(fused)  # (B, C, H, W)


# Sample instantiation
mfa = ModalityFieldAdapter(in_channels=1, cond_dim=3)
mfa