|
""" |
|
Learned linear estimator module for OFDM channel estimation. |
|
|
|
This module implements an estimator for transforming channel estimates at |
|
pilot signals to complete channel estimates using a learned linear transformation. |
|
""" |
|
|
|
from typing import Tuple |
|
import logging |
|
import torch |
|
import torch.nn as nn |
|
|
|
from src.config.schemas import SystemConfig, ModelConfig |
|
|
|
|
|
class LinearEstimator(nn.Module):
    """Learned linear (MMSE-style) channel estimator.

    The pilot-grid channel estimate is flattened, passed through one
    ``nn.Linear`` layer and reshaped to the OFDM frame, i.e.
    ``h_hat = W @ h_pilot + b``. The layer is meant to be fitted by
    stochastic gradient descent on ``|h_hat - h_ideal|^2``; the training
    loop itself is not part of this class.

    Attributes:
        system_config (SystemConfig): Configuration object the OFDM and
            pilot grid sizes are read from.
        model_config (ModelConfig): Configuration object the target device
            is read from.
        device (torch.device): Device requested by ``model_config.device``.
            The module is moved there at construction time. ``forward``
            does not rely on this record; it follows the device the layer's
            weights actually live on.
        logger (logging.Logger): Logger named after this module.
        ofdm_size (Tuple[int, int]): OFDM frame dimensions, taken from
            ``system_config.ofdm`` as ``(num_scs, num_symbols)``.
        pilot_size (Tuple[int, int]): Pilot grid dimensions, taken from
            ``system_config.pilot`` as ``(num_scs, num_symbols)``.
        linear (nn.Linear): Layer mapping the flattened pilot grid
            (``pilot_size[0] * pilot_size[1]`` features) to the flattened
            OFDM frame (``ofdm_size[0] * ofdm_size[1]`` features).
    """

    def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
        """Build the linear layer and move the module to the target device.

        Args:
            system_config: Object exposing ``ofdm.num_scs``,
                ``ofdm.num_symbols``, ``pilot.num_scs`` and
                ``pilot.num_symbols``.
            model_config: Object exposing ``device``, anything accepted by
                ``torch.device``.
        """
        super().__init__()

        self.system_config = system_config
        self.model_config = model_config
        self.device = torch.device(model_config.device)
        self.logger = logging.getLogger(__name__)

        self.ofdm_size: Tuple[int, int] = (
            system_config.ofdm.num_scs,
            system_config.ofdm.num_symbols,
        )
        self.pilot_size: Tuple[int, int] = (
            system_config.pilot.num_scs,
            system_config.pilot.num_symbols,
        )

        # The layer works on flattened grids, so feature counts are the
        # products of the two grid dimensions.
        in_feature_dim = self.pilot_size[0] * self.pilot_size[1]
        out_feature_dim = self.ofdm_size[0] * self.ofdm_size[1]

        self.logger.info("Initializing LinearEstimator:")
        self.logger.info(" OFDM size: %s", self.ofdm_size)
        self.logger.info(" Pilot size: %s", self.pilot_size)
        self.logger.info(" Input features: %s", in_feature_dim)
        self.logger.info(" Output features: %s", out_feature_dim)
        self.logger.info(" Device: %s", self.device)

        self.linear = nn.Linear(in_feature_dim, out_feature_dim)
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Estimate the full OFDM frame from the pilot grid.

        Args:
            x: Tensor of shape ``(batch_size, pilot_size[0], pilot_size[1])``.
                It is moved to the device of the layer's weights before use.

        Returns:
            Tensor of shape ``(batch_size, ofdm_size[0], ofdm_size[1])``.

        Raises:
            ValueError: If ``x`` does not have the expected shape.
        """
        # Follow the parameters rather than the device recorded in __init__:
        # that record goes stale once the module is moved with .to()/.cuda().
        x = x.to(self.linear.weight.device)
        # Lazy %-style arguments: the shape is only formatted when DEBUG
        # logging is actually enabled.
        self.logger.debug("Input shape: %s", x.size())

        if x.dim() == 0:
            # A scalar has no batch dimension to read; report it as a shape
            # error instead of letting x.size(0) raise IndexError.
            raise ValueError(
                f"Expected a 3-D input of shape "
                f"(batch_size, {self.pilot_size[0]}, {self.pilot_size[1]}), "
                f"got {x.size()}"
            )

        batch_size = x.size(0)
        expected_shape = (batch_size, self.pilot_size[0], self.pilot_size[1])
        if x.size() != expected_shape:
            raise ValueError(
                f"Expected input shape {expected_shape}, got {x.size()}"
            )

        x = torch.flatten(x, start_dim=1)
        self.logger.debug("Flattened shape: %s", x.size())

        x = self.linear(x)
        self.logger.debug("Linear output shape: %s", x.size())

        # The batch size is already validated, so spell it out instead of
        # letting reshape infer it with -1.
        x = x.reshape(batch_size, self.ofdm_size[0], self.ofdm_size[1])
        self.logger.debug("Reshaped output shape: %s", x.size())

        return x

    def __repr__(self) -> str:
        """Return a short multi-line summary of the grid sizes and device."""
        return (
            f"LinearEstimator(\n"
            f" ofdm_size={self.ofdm_size},\n"
            f" pilot_size={self.pilot_size},\n"
            f" device={self.device}\n"
            f")"
        )