# src/models/fortitran.py
import torch
from torch import nn
import logging
from typing import Tuple, List, Optional
from src.config.schemas import SystemConfig, ModelConfig
from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, ChannelAdapter
class BaseFortiTranEstimator(nn.Module):
"""
Base Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
This model performs channel estimation by:
    1. Upsampling pilot symbols to the full OFDM grid size with a linear layer
    2. Applying convolutional enhancement to subcarrier-symbol features
    3. Converting the enhanced grid to patch embeddings for transformer processing
    4. Using a transformer encoder to capture long-range dependencies
    5. Reconstructing the subcarrier-symbol representation and applying a residual connection
    6. Applying final convolutional refinement for high-quality channel estimates
"""
def __init__(self, system_config: SystemConfig, model_config: ModelConfig,
use_channel_adaptation: bool = False) -> None:
"""
Initialize the BaseFortiTranEstimator.
Args:
system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
model_config: Model architecture configuration (patch size, layers, etc.)
use_channel_adaptation: Whether to enable channel adaptation features (disabled for FortiTran)
"""
super().__init__()
self.system_config = system_config
self.model_config = model_config
self.use_channel_adaptation = use_channel_adaptation
self.device = torch.device(model_config.device)
self.logger = logging.getLogger(self.__class__.__name__)
# Cache key dimensions for efficiency
self._setup_dimensions()
# Initialize model components
self._build_architecture()
# Move model to specified device
self.to(self.device)
self._log_initialization_info()
def _setup_dimensions(self) -> None:
"""Calculate and cache key dimensions from configuration."""
# OFDM grid dimensions
self.ofdm_size = (
self.system_config.ofdm.num_scs,
self.system_config.ofdm.num_symbols
)
# Pilot arrangement dimensions
self.pilot_size = (
self.system_config.pilot.num_scs,
self.system_config.pilot.num_symbols
)
# Feature dimensions for linear layers
self.pilot_features = self.pilot_size[0] * self.pilot_size[1]
self.ofdm_features = self.ofdm_size[0] * self.ofdm_size[1]
# Patch processing dimensions
self.patch_length = (
self.model_config.patch_size[0] * self.model_config.patch_size[1]
)
# Transformer input dimension (includes channel tokens if adaptation is enabled)
if self.use_channel_adaptation:
if self.model_config.adaptive_token_length is None:
raise ValueError("adaptive_token_length must be set when channel adaptation is enabled")
self.transformer_input_dim = self.patch_length + self.model_config.adaptive_token_length
else:
self.transformer_input_dim = self.patch_length
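
        # Worked example (assumed numbers, not config constants): a 4x2 patch
        # gives patch_length = 8; with adaptation enabled and
        # adaptive_token_length = 16, the transformer consumes tokens of length
        # 8 + 16 = 24 while still emitting tokens of plain patch_length 8
        # (see _build_architecture).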
def _build_architecture(self) -> None:
"""Construct the model architecture components."""
# 1. Pilot-to-OFDM upsampling
self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)
# 2. Initial convolutional enhancement
self.initial_enhancer = ConvEnhancer()
# 3. Patch embedding for transformer processing
self.patch_embedder = PatchEmbedding(self.model_config.patch_size)
# 4. Channel adapter (conditional on use_channel_adaptation)
if self.use_channel_adaptation:
if self.model_config.channel_adaptivity_hidden_sizes is None:
raise ValueError("channel_adaptivity_hidden_sizes must be set when channel adaptation is enabled")
# Convert list to tuple as expected by ChannelAdapter (exactly 3 values)
hidden_sizes = tuple(self.model_config.channel_adaptivity_hidden_sizes)
if len(hidden_sizes) != 3:
raise ValueError("channel_adaptivity_hidden_sizes must have exactly 3 values")
self.channel_adapter = ChannelAdapter(hidden_sizes)
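            # The adapter encodes (snr, delay_spread, max_dop_shift) into
            # features that are concatenated to each patch embedding along the
            # feature dimension (see _forward_real_valued).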
# 5. Transformer encoder for sequence modeling
transformer_output_dim = self.patch_length # Always output standard patch length
self.transformer_encoder = TransformerEncoderForChannels(
input_dim=self.transformer_input_dim,
output_dim=transformer_output_dim,
model_dim=self.model_config.model_dim,
num_head=self.model_config.num_head,
activation=self.model_config.activation,
dropout=self.model_config.dropout,
num_layers=self.model_config.num_layers,
max_len=self.model_config.max_seq_len,
pos_encoding_type=self.model_config.pos_encoding_type
)
# 6. Patch reconstruction
self.patch_reconstructor = InversePatchEmbedding(
self.ofdm_size,
self.model_config.patch_size
)
# 7. Final convolutional refinement
self.final_refiner = ConvEnhancer()
def _log_initialization_info(self) -> None:
"""Log model initialization details."""
adaptation_status = "enabled" if self.use_channel_adaptation else "disabled"
self.logger.info(f"{self.__class__.__name__} initialized successfully:")
self.logger.info(f" Channel adaptation: {adaptation_status}")
self.logger.info(f" OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
self.logger.info(f" Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
self.logger.info(f" Patch size: {self.model_config.patch_size}")
self.logger.info(f" Model dimension: {self.model_config.model_dim}")
self.logger.info(f" Transformer layers: {self.model_config.num_layers}")
self.logger.info(f" Device: {self.device}")
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
self.logger.info(f" Total parameters: {total_params:,}")
self.logger.info(f" Trainable parameters: {trainable_params:,}")
def forward(self, pilot_symbols: torch.Tensor, meta_data: Optional[Tuple] = None) -> torch.Tensor:
"""
Forward pass for channel estimation.
Args:
pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
            meta_data: Tuple of channel metadata whose second through fourth
                entries are SNR, delay spread, and maximum Doppler shift
                (used only when channel adaptation is enabled)
Returns:
Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]
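
        Example (illustrative shapes; assumes a 12x2 pilot grid and an
        already-constructed `model`):
            >>> pilots = torch.randn(8, 12, 2, dtype=torch.complex64)
            >>> h_hat = model(pilots)  # complex, [8, ofdm_scs, ofdm_symbols]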
"""
# Validate inputs based on adaptation mode
if self.use_channel_adaptation and meta_data is None:
raise ValueError("meta_data is required when channel adaptation is enabled")
if not self.use_channel_adaptation and meta_data is not None:
self.logger.warning("meta_data provided but channel adaptation is disabled - ignoring meta_data")
# Extract channel conditions if adaptation is enabled
channel_conditions = None
if self.use_channel_adaptation and meta_data is not None:
_, snr, delay_spread, max_dop_shift, _, _ = meta_data
channel_conditions = [
tensor.to(self.device)
for tensor in (snr, delay_spread, max_dop_shift)
]
# Ensure input is on correct device
pilot_symbols = pilot_symbols.to(self.device)
# Process real and imaginary parts separately
real_estimate = self._forward_real_valued(pilot_symbols.real, channel_conditions)
imag_estimate = self._forward_real_valued(pilot_symbols.imag, channel_conditions)
# Combine into complex tensor
channel_estimate = torch.complex(real_estimate, imag_estimate)
return channel_estimate
def _forward_real_valued(self, x: torch.Tensor,
channel_conditions: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
"""
Process real-valued input through the estimation pipeline.
Args:
x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
channel_conditions: Channel conditions for adaptation (optional)
Returns:
Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
"""
batch_size = x.shape[0]
# Flatten subcarrier and symbol dimensions for linear upsampling
if x.dim() > 2:
x = x.view(batch_size, -1)
# Stage 1: Upsample from pilot grid to OFDM grid
upsampled = self.pilot_upsampler(x)
# Reshape for convolutional processing
upsampled_2d = upsampled.view(batch_size, 1, *self.ofdm_size)
# Stage 2: Initial convolutional enhancement
        conv_enhanced = self.initial_enhancer(upsampled_2d).squeeze(dim=1)
# Stage 3: Convert to patch embeddings
patch_embeddings = self.patch_embedder(conv_enhanced)
# Stage 4: Apply channel adaptation if enabled
if self.use_channel_adaptation and channel_conditions is not None:
encoded_channel_condition = self.channel_adapter(*channel_conditions)
transformer_input = torch.cat((patch_embeddings, encoded_channel_condition), dim=2)
else:
transformer_input = patch_embeddings
# Stage 5: Transformer processing for long-range dependencies
transformer_output = self.transformer_encoder(transformer_input)
# Stage 6: Reconstruct subcarrier-symbol representation
reconstructed = self.patch_reconstructor(transformer_output)
# Stage 7: Apply residual connection
residual_combined = conv_enhanced + reconstructed
# Stage 8: Final convolutional refinement
        refined_output = self.final_refiner(residual_combined.unsqueeze(dim=1)).squeeze(dim=1)
return refined_output
def get_model_info(self) -> dict:
"""Return model configuration and statistics."""
return {
'model_name': self.__class__.__name__,
'channel_adaptation': self.use_channel_adaptation,
'ofdm_size': self.ofdm_size,
'pilot_size': self.pilot_size,
'patch_size': self.model_config.patch_size,
'patch_length': self.patch_length,
'transformer_input_dim': self.transformer_input_dim,
'model_dim': self.model_config.model_dim,
'num_layers': self.model_config.num_layers,
'device': str(self.device),
'total_parameters': sum(p.numel() for p in self.parameters()),
'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
}
class FortiTranEstimator(BaseFortiTranEstimator):
"""
Standard Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
    This is the plain variant, without channel adaptation features.
"""
def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
"""
Initialize the FortiTranEstimator.
Args:
system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
model_config: Model architecture configuration (patch size, layers, etc.)
"""
super().__init__(system_config, model_config, use_channel_adaptation=False)
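

# Usage sketch (hypothetical; the SystemConfig/ModelConfig instances are
# assumed to be loaded from the project's configuration files rather than
# constructed here):
#
#     model = FortiTranEstimator(system_config, model_config)
#     print(model.get_model_info())
#     pilots = torch.randn(8, *model.pilot_size, dtype=torch.complex64)
#     h_hat = model(pilots)  # complex estimate of shape [8, *model.ofdm_size]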