"""
Adaptive Fusion Module for Hybrid Food Classifier

Combines CNN and ViT features using a cross-attention mechanism.
"""

import torch
import torch.nn as nn
from typing import Tuple


class AdaptiveFusionModule(nn.Module):
    """Adaptive fusion module combining CNN and ViT features via cross-attention."""

    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.2,
        spatial_size: int = 7,
    ):
        super().__init__()

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.spatial_size = spatial_size
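
        # Cross-attention: CNN tokens act as queries over ViT patch tokens.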
        self.cnn_to_vit_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
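
        # Cross-attention: ViT patch tokens act as queries over CNN tokens.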
        self.vit_to_cnn_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
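
        # Self-attention over the concatenated CNN and ViT token sequence.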
        self.self_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
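
        # Per-branch projections applied to the token sequences before attention.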
        self.cnn_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.vit_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
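
        # MLP that fuses the concatenated global features back to feature_dim.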
        self.global_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
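
        # Gating network: per-sample softmax weights over the two branches.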
        self.adaptive_weight = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1),
        )
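
        # Final projection from feature_dim to the module's output size (hidden_dim).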
        self.final_proj = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        cnn_spatial: torch.Tensor,
        cnn_global: torch.Tensor,
        vit_spatial: torch.Tensor,
        vit_global: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            cnn_spatial: CNN spatial features [B, feature_dim, spatial_size, spatial_size]
            cnn_global: CNN global features [B, feature_dim]
            vit_spatial: ViT patch features [B, num_patches, feature_dim]
            vit_global: ViT CLS token features [B, feature_dim]

        Returns:
            fused_spatial: Fused token features [B, spatial_size**2 + num_patches, feature_dim]
            fused_global: Fused global features [B, hidden_dim]
        """
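        # Flatten the CNN feature map [B, C, H, W] into a token sequence [B, H*W, C].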
        cnn_spatial_seq = cnn_spatial.flatten(2).transpose(1, 2)

        cnn_spatial_proj = self.cnn_spatial_proj(cnn_spatial_seq)
        vit_spatial_proj = self.vit_spatial_proj(vit_spatial)
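
        # CNN tokens attend to ViT tokens (queries from the CNN branch).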
        cnn_attended, _ = self.cnn_to_vit_attention(
            query=cnn_spatial_proj,
            key=vit_spatial_proj,
            value=vit_spatial_proj,
        )
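
        # ViT tokens attend to CNN tokens (queries from the ViT branch).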
        vit_attended, _ = self.vit_to_cnn_attention(
            query=vit_spatial_proj,
            key=cnn_spatial_proj,
            value=cnn_spatial_proj,
        )
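
        # Residual connection per branch, then concatenate along the token axis.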
        combined_spatial = torch.cat([
            cnn_attended + cnn_spatial_proj,
            vit_attended + vit_spatial_proj,
        ], dim=1)
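
        # Self-attention over the combined sequence to mix the two token sets.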
        fused_spatial, _ = self.self_attention(
            query=combined_spatial,
            key=combined_spatial,
            value=combined_spatial,
        )
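
        # Fuse the concatenated global descriptors through the shared MLP.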
        global_concat = torch.cat([cnn_global, vit_global], dim=-1)
        fused_global_base = self.global_fusion(global_concat)
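
        # Per-sample branch weights from the gating network (softmax, sum to 1).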
        weights = self.adaptive_weight(global_concat)
        cnn_weight = weights[:, 0:1]
        vit_weight = weights[:, 1:2]
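
        # Average the adaptively weighted branch mix with the MLP-fused features.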
        fused_global = (
            cnn_weight * cnn_global
            + vit_weight * vit_global
            + fused_global_base
        ) / 2
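
        # Project the fused global vector to the module's output dimension.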
        fused_global = self.final_proj(fused_global)

        return fused_spatial, fused_global

    def get_output_dim(self) -> int:
        """Return the dimensionality of the fused global output (hidden_dim)."""
        return self.hidden_dim
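

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training
    # pipeline). The shapes below are assumptions: a ResNet-style 7x7 feature
    # map and a ViT with 196 patch tokens, both at feature_dim=768. Adjust
    # them to match the actual backbones used in the hybrid classifier.
    module = AdaptiveFusionModule(feature_dim=768, hidden_dim=512)
    B, P = 2, 196
    cnn_spatial = torch.randn(B, 768, 7, 7)   # CNN feature map
    cnn_global = torch.randn(B, 768)          # pooled CNN features
    vit_spatial = torch.randn(B, P, 768)      # ViT patch tokens
    vit_global = torch.randn(B, 768)          # ViT CLS token
    fused_spatial, fused_global = module(
        cnn_spatial, cnn_global, vit_spatial, vit_global
    )
    print(fused_spatial.shape)  # torch.Size([2, 245, 768]): 49 + 196 tokens
    print(fused_global.shape)   # torch.Size([2, 512]): hidden_dim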