File size: 4,311 Bytes
2ef7f80
 
 
 
 
 
 
 
 
 
 
 
687eaba
2ef7f80
 
 
 
 
687eaba
 
2ef7f80
 
687eaba
 
 
 
 
 
 
 
2ef7f80
 
687eaba
2ef7f80
 
 
687eaba
 
2ef7f80
 
 
687eaba
 
 
2ef7f80
 
 
687eaba
 
2ef7f80
 
687eaba
 
2ef7f80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687eaba
2ef7f80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Learned linear estimator module for OFDM channel estimation.

This module implements an estimator for transforming channel estimates at
pilot signals to complete channel estimates using a learned linear transformation.
"""

from typing import Tuple
import logging
import torch
import torch.nn as nn

from src.config.schemas import SystemConfig, ModelConfig


class LinearEstimator(nn.Module):
    """Learned MMSE-style estimator mapping pilot channel estimates to a full grid.

    The pilot grid is flattened and passed through a single ``nn.Linear`` layer,
    so the estimate is the affine map ``h_hat = W @ h_pilot + b`` (``nn.Linear``
    includes a bias). ``W`` and ``b`` are meant to be fit by gradient descent on
    ``|h_hat - h_ideal|^2``; no training code lives in this class.

    Attributes:
        device (torch.device): Target device for computation (from ``model_config.device``)
        system_config (SystemConfig): Validated configuration object for OFDM system parameters
        model_config (ModelConfig): Validated configuration object for model parameters
        logger (logging.Logger): Module-level logger used for init/debug messages
        ofdm_size (Tuple[int, int]): Dimensions of OFDM frame as (num_subcarriers, num_symbols)
            num_subcarriers (int): number of sub-carriers
            num_symbols (int): number of OFDM symbols
        pilot_size (Tuple[int, int]): Dimensions of pilot signal as (num_subcarriers, num_symbols)
            num_subcarriers (int): number of pilots across sub-carriers
            num_symbols (int): number of pilots across OFDM symbols
        linear (nn.Linear): Layer mapping the flattened pilot grid
            (pilot_size[0] * pilot_size[1] features) to the flattened OFDM grid
            (ofdm_size[0] * ofdm_size[1] features)
    """

    def __init__(self, system_config: "SystemConfig", model_config: "ModelConfig") -> None:
        """Initialize the estimator and move it to the configured device.

        Args:
            system_config: Validated SystemConfig object containing OFDM system parameters
                (reads ``ofdm.num_scs``, ``ofdm.num_symbols``, ``pilot.num_scs``,
                ``pilot.num_symbols``)
            model_config: Validated ModelConfig object containing model parameters
                (reads ``device``)
        """
        super().__init__()

        self.system_config = system_config
        self.model_config = model_config
        self.device = torch.device(model_config.device)
        self.logger = logging.getLogger(__name__)

        # Extract dimensions from validated config
        self.ofdm_size = (system_config.ofdm.num_scs, system_config.ofdm.num_symbols)
        self.pilot_size = (system_config.pilot.num_scs, system_config.pilot.num_symbols)

        # Feature dimensions of the flattened grids; derived from the stored sizes so
        # they always agree with the shape validation in forward().
        in_feature_dim = self.pilot_size[0] * self.pilot_size[1]
        out_feature_dim = self.ofdm_size[0] * self.ofdm_size[1]

        self.logger.info("Initializing LinearEstimator:")
        self.logger.info("  OFDM size: %s", self.ofdm_size)
        self.logger.info("  Pilot size: %s", self.pilot_size)
        self.logger.info("  Input features: %s", in_feature_dim)
        self.logger.info("  Output features: %s", out_feature_dim)
        self.logger.info("  Device: %s", self.device)

        # Created with PyTorch's default (real floating-point) dtype.
        self.linear = nn.Linear(in_feature_dim, out_feature_dim)
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Estimate the full-grid channel from pilot-position estimates.

        Args:
            x: Input tensor containing pilot channel estimates with shape
               (batch_size, pilot_size[0], pilot_size[1]). Because ``self.linear``
               uses the default real dtype, ``x`` presumably must be a real
               floating-point tensor (complex estimates split upstream) - TODO confirm.

        Returns:
            Estimated channel tensor with shape
            (batch_size, ofdm_size[0], ofdm_size[1]). ``batch_size`` may be 0.

        Raises:
            ValueError: if ``x`` is not 3-D or its trailing dims differ from ``pilot_size``.
        """
        # pytorch does nothing if input is already on correct device
        x = x.to(self.device)
        # Lazy %-args: no string formatting unless DEBUG is enabled (runs every batch).
        self.logger.debug("Input shape: %s", x.size())

        # Check rank first so a 0-dim tensor raises ValueError rather than an
        # IndexError from x.size(0).
        if x.dim() != 3:
            raise ValueError(
                f"Expected a 3-D input of shape (batch_size, {self.pilot_size[0]}, "
                f"{self.pilot_size[1]}), got {x.size()}"
            )
        batch_size = x.size(0)
        expected_shape = (batch_size, self.pilot_size[0], self.pilot_size[1])
        if x.size() != expected_shape:
            raise ValueError(
                f"Expected input shape {expected_shape}, got {x.size()}"
            )

        # Flatten input for linear transformation
        x = torch.flatten(x, start_dim=1)
        self.logger.debug("Flattened shape: %s", x.size())

        # Apply linear transformation
        x = self.linear(x)
        self.logger.debug("Linear output shape: %s", x.size())

        # Reshape with an explicit batch size: reshape(-1, ...) cannot infer the
        # leading dim of a 0-element tensor and raises for an empty batch.
        x = x.reshape(batch_size, self.ofdm_size[0], self.ofdm_size[1])
        self.logger.debug("Reshaped output shape: %s", x.size())

        return x

    def __repr__(self) -> str:
        """String representation of the estimator (grid sizes and device)."""
        return (
            f"LinearEstimator(\n"
            f"  ofdm_size={self.ofdm_size},\n"
            f"  pilot_size={self.pilot_size},\n"
            f"  device={self.device}\n"
            f")"
        )