import logging
from typing import Tuple, List, Optional

import torch
from torch import nn

from src.config.schemas import SystemConfig, ModelConfig
from src.models.blocks import (
    ConvEnhancer,
    PatchEmbedding,
    InversePatchEmbedding,
    TransformerEncoderForChannels,
    ChannelAdapter,
)


class BaseFortiTranEstimator(nn.Module):
    """
    Base Hybrid CNN-Transformer Channel Estimator for OFDM Systems.

    This model performs channel estimation by:
    1. Upsampling pilot symbols to the full OFDM grid size with a linear layer
    2. Applying convolutional enhancement to the subcarrier-symbol features
    3. Converting the enhanced grid to patch embeddings for transformer processing
    4. Using a transformer encoder to capture long-range dependencies
    5. Reconstructing the subcarrier-symbol representation and applying a residual connection
    6. Applying a final convolutional refinement for high-quality channel estimates
    """

    def __init__(self, system_config: SystemConfig, model_config: ModelConfig,
                 use_channel_adaptation: bool = False) -> None:
        """
        Initialize the BaseFortiTranEstimator.

        Args:
            system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
            model_config: Model architecture configuration (patch size, layers, etc.)
            use_channel_adaptation: Whether to enable channel adaptation features (disabled for FortiTran)
        """
        super().__init__()

        self.system_config = system_config
        self.model_config = model_config
        self.use_channel_adaptation = use_channel_adaptation
        self.device = torch.device(model_config.device)
        self.logger = logging.getLogger(self.__class__.__name__)

        self._setup_dimensions()

        self._build_architecture()

        self.to(self.device)

        self._log_initialization_info()

    def _setup_dimensions(self) -> None:
        """Calculate and cache key dimensions from configuration."""

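        # Full OFDM resource grid and pilot grid sizes, as (num_subcarriers, num_symbols)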
        self.ofdm_size = (
            self.system_config.ofdm.num_scs,
            self.system_config.ofdm.num_symbols
        )
        self.pilot_size = (
            self.system_config.pilot.num_scs,
            self.system_config.pilot.num_symbols
        )

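        # Flattened element counts used by the linear pilot upsampler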
        self.pilot_features = self.pilot_size[0] * self.pilot_size[1]
        self.ofdm_features = self.ofdm_size[0] * self.ofdm_size[1]

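        # Number of grid elements covered by one patch (patch height x patch width)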
        self.patch_length = (
            self.model_config.patch_size[0] * self.model_config.patch_size[1]
        )

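        # Transformer input width: patch features, optionally extended by the adaptive channel-condition token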
        if self.use_channel_adaptation:
            if self.model_config.adaptive_token_length is None:
                raise ValueError("adaptive_token_length must be set when channel adaptation is enabled")
            self.transformer_input_dim = self.patch_length + self.model_config.adaptive_token_length
        else:
            self.transformer_input_dim = self.patch_length

    def _build_architecture(self) -> None:
        """Construct the model architecture components."""

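        # Linear upsampler from the flattened pilot grid to the flattened OFDM grid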
        self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)

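        # Convolutional enhancement applied to the upsampled grid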
        self.initial_enhancer = ConvEnhancer()

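        # Patch embedding of the enhanced subcarrier-symbol grid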
        self.patch_embedder = PatchEmbedding(self.model_config.patch_size)

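        # Optional adapter that encodes channel conditions (SNR, delay spread, max Doppler shift)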
        if self.use_channel_adaptation:
            if self.model_config.channel_adaptivity_hidden_sizes is None:
                raise ValueError("channel_adaptivity_hidden_sizes must be set when channel adaptation is enabled")

            hidden_sizes = tuple(self.model_config.channel_adaptivity_hidden_sizes)
            if len(hidden_sizes) != 3:
                raise ValueError("channel_adaptivity_hidden_sizes must have exactly 3 values")
            self.channel_adapter = ChannelAdapter(hidden_sizes)

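        # The transformer maps (possibly extended) patch features back to plain patch features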
        transformer_output_dim = self.patch_length

        self.transformer_encoder = TransformerEncoderForChannels(
            input_dim=self.transformer_input_dim,
            output_dim=transformer_output_dim,
            model_dim=self.model_config.model_dim,
            num_head=self.model_config.num_head,
            activation=self.model_config.activation,
            dropout=self.model_config.dropout,
            num_layers=self.model_config.num_layers,
            max_len=self.model_config.max_seq_len,
            pos_encoding_type=self.model_config.pos_encoding_type
        )

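        # Folds the transformer output back into the 2-D OFDM grid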
        self.patch_reconstructor = InversePatchEmbedding(
            self.ofdm_size,
            self.model_config.patch_size
        )

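        # Final convolutional refinement stage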
        self.final_refiner = ConvEnhancer()

    def _log_initialization_info(self) -> None:
        """Log model initialization details."""
        adaptation_status = "enabled" if self.use_channel_adaptation else "disabled"
        self.logger.info(f"{self.__class__.__name__} initialized successfully:")
        self.logger.info(f" Channel adaptation: {adaptation_status}")
        self.logger.info(f" OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
        self.logger.info(f" Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
        self.logger.info(f" Patch size: {self.model_config.patch_size}")
        self.logger.info(f" Model dimension: {self.model_config.model_dim}")
        self.logger.info(f" Transformer layers: {self.model_config.num_layers}")
        self.logger.info(f" Device: {self.device}")

        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.logger.info(f" Total parameters: {total_params:,}")
        self.logger.info(f" Trainable parameters: {trainable_params:,}")

    def forward(self, pilot_symbols: torch.Tensor, meta_data: Optional[Tuple] = None) -> torch.Tensor:
        """
        Forward pass for channel estimation.

        Args:
            pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
            meta_data: Channel conditions (only used if channel adaptation is enabled)

        Returns:
            Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]
        """
        if self.use_channel_adaptation and meta_data is None:
            raise ValueError("meta_data is required when channel adaptation is enabled")

        if not self.use_channel_adaptation and meta_data is not None:
            self.logger.warning("meta_data provided but channel adaptation is disabled - ignoring meta_data")

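        # Unpack the channel conditions (SNR, delay spread, max Doppler shift) used by the adapter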
        channel_conditions = None
        if self.use_channel_adaptation and meta_data is not None:
            _, snr, delay_spread, max_dop_shift, _, _ = meta_data
            channel_conditions = [
                tensor.to(self.device)
                for tensor in (snr, delay_spread, max_dop_shift)
            ]

        pilot_symbols = pilot_symbols.to(self.device)

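        # Real and imaginary parts are estimated separately by the same real-valued pipeline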
        real_estimate = self._forward_real_valued(pilot_symbols.real, channel_conditions)
        imag_estimate = self._forward_real_valued(pilot_symbols.imag, channel_conditions)

        channel_estimate = torch.complex(real_estimate, imag_estimate)

        return channel_estimate

    def _forward_real_valued(self, x: torch.Tensor,
                             channel_conditions: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Process real-valued input through the estimation pipeline.

        Args:
            x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
            channel_conditions: Channel conditions for adaptation (optional)

        Returns:
            Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
        """
        batch_size = x.shape[0]

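        # Flatten a [batch, pilot_scs, pilot_symbols] grid to [batch, pilot_features]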
        if x.dim() > 2:
            x = x.view(batch_size, -1)

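        # Step 1: linear upsampling to the full grid, reshaped to [batch, 1, num_scs, num_symbols]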
        upsampled = self.pilot_upsampler(x)
        upsampled_2d = upsampled.view(batch_size, 1, *self.ofdm_size)

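        # Step 2: convolutional enhancement; the singleton channel dimension is squeezed out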
        conv_enhanced = torch.squeeze(self.initial_enhancer(upsampled_2d), dim=1)

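        # Step 3: convert the enhanced grid into a sequence of patch embeddings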
        patch_embeddings = self.patch_embedder(conv_enhanced)

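        # Append the encoded channel conditions to each patch embedding when adaptation is enabled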
        if self.use_channel_adaptation and channel_conditions is not None:
            encoded_channel_condition = self.channel_adapter(*channel_conditions)
            transformer_input = torch.cat((patch_embeddings, encoded_channel_condition), dim=2)
        else:
            transformer_input = patch_embeddings

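        # Step 4: transformer encoder captures long-range dependencies across patches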
        transformer_output = self.transformer_encoder(transformer_input)

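        # Step 5: reconstruct the grid from patches and add the conv-enhanced input as a residual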
        reconstructed = self.patch_reconstructor(transformer_output)
        residual_combined = conv_enhanced + reconstructed

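        # Step 6: final convolutional refinement (a channel dimension is added for the CNN, then removed)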
        refined_output = torch.squeeze(self.final_refiner(torch.unsqueeze(residual_combined, dim=1)), dim=1)

        return refined_output

    def get_model_info(self) -> dict:
        """Return model configuration and statistics."""
        return {
            'model_name': self.__class__.__name__,
            'channel_adaptation': self.use_channel_adaptation,
            'ofdm_size': self.ofdm_size,
            'pilot_size': self.pilot_size,
            'patch_size': self.model_config.patch_size,
            'patch_length': self.patch_length,
            'transformer_input_dim': self.transformer_input_dim,
            'model_dim': self.model_config.model_dim,
            'num_layers': self.model_config.num_layers,
            'device': str(self.device),
            'total_parameters': sum(p.numel() for p in self.parameters()),
            'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
        }


class FortiTranEstimator(BaseFortiTranEstimator):
    """
    Standard Hybrid CNN-Transformer Channel Estimator for OFDM Systems.

    This variant does not use the channel adaptation features.
    """

    def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
        """
        Initialize the FortiTranEstimator.

        Args:
            system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
            model_config: Model architecture configuration (patch size, layers, etc.)
        """
        super().__init__(system_config, model_config, use_channel_adaptation=False)
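

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes SystemConfig / ModelConfig
# instances are constructed elsewhere (e.g. by the project's config loading
# code); their exact construction is not shown here.
#
#   estimator = FortiTranEstimator(system_config, model_config)
#   pilots = torch.randn(8, *estimator.pilot_size, dtype=torch.complex64)
#   h_hat = estimator(pilots)  # complex tensor of shape [8, num_scs, num_symbols]
# ---------------------------------------------------------------------------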