BerkIGuler's picture
fixes on src/models
687eaba
"""
Learned linear estimator module for OFDM channel estimation.
This module implements an estimator for transforming channel estimates at
pilot signals to complete channel estimates using a learned linear transformation.
"""
from typing import Tuple
import logging
import torch
import torch.nn as nn
from src.config.schemas import SystemConfig, ModelConfig
class LinearEstimator(nn.Module):
"""Learned MMSE estimator.
Find W such that W*h_pilot = h_hat, where h_hat is the estimated channel by stochastic gradient descent on |h_hat - h_ideal|^2
Attributes:
device (torch.device): Target device for computation
system_config (SystemConfig): Validated configuration object for OFDM system parameters
model_config (ModelConfig): Validated configuration object for model parameters
ofdm_size (Tuple[int, int]): Dimensions of OFDM frame as (num_subcarriers, num_symbols)
num_subcarriers (int): number of sub-carriers
num_symbols (int): number of OFDM symbols
pilot_size (Tuple[int, int]): Dimensions of pilot signal as (num_subcarriers, num_symbols)
num_subcarriers (int): number of pilots across sub-carriers
num_symbols (int): number of pilots across OFDM symbols
"""
def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
"""Initialize the MMSE estimator.
Args:
system_config: Validated SystemConfig object containing OFDM system parameters
model_config: Validated ModelConfig object containing model parameters
"""
super().__init__()
self.system_config = system_config
self.model_config = model_config
self.device = torch.device(model_config.device)
self.logger = logging.getLogger(__name__)
# Extract dimensions from validated config
self.ofdm_size = (system_config.ofdm.num_scs, system_config.ofdm.num_symbols)
self.pilot_size = (system_config.pilot.num_scs, system_config.pilot.num_symbols)
# Calculate feature dimensions
in_feature_dim = system_config.pilot.num_scs * system_config.pilot.num_symbols
out_feature_dim = system_config.ofdm.num_scs * system_config.ofdm.num_symbols
self.logger.info(f"Initializing LinearEstimator:")
self.logger.info(f" OFDM size: {self.ofdm_size}")
self.logger.info(f" Pilot size: {self.pilot_size}")
self.logger.info(f" Input features: {in_feature_dim}")
self.logger.info(f" Output features: {out_feature_dim}")
self.logger.info(f" Device: {self.device}")
# Create linear layer
self.linear = nn.Linear(in_feature_dim, out_feature_dim)
self.to(self.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the MMSE estimator.
Args:
x: Input tensor containing pilot signals with shape
(batch_size, pilot_size[0], pilot_size[1])
Returns:
Estimated OFDM signal tensor with shape
(batch_size, ofdm_size[0], ofdm_size[1])
"""
# pytorch does nothing if input is already on correct device
x = x.to(self.device)
self.logger.debug(f"Input shape: {x.size()}")
# Validate input shape
expected_shape = (x.size(0), self.pilot_size[0], self.pilot_size[1])
if x.size() != expected_shape:
raise ValueError(
f"Expected input shape {expected_shape}, got {x.size()}"
)
# Flatten input for linear transformation
x = torch.flatten(x, start_dim=1)
self.logger.debug(f"Flattened shape: {x.size()}")
# Apply linear transformation
x = self.linear(x)
self.logger.debug(f"Linear output shape: {x.size()}")
# Reshape to OFDM dimensions
x = x.reshape(-1, self.ofdm_size[0], self.ofdm_size[1])
self.logger.debug(f"Reshaped output shape: {x.size()}")
return x
def __repr__(self) -> str:
"""String representation of the estimator."""
return (
f"LinearEstimator(\n"
f" ofdm_size={self.ofdm_size},\n"
f" pilot_size={self.pilot_size},\n"
f" device={self.device}\n"
f")"
)