File size: 10,558 Bytes
9727e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c90d5f
 
 
 
 
 
9727e5e
0c90d5f
 
 
 
 
 
9727e5e
0c90d5f
 
 
 
 
 
 
 
9727e5e
 
0c90d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9727e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
0c90d5f
 
 
 
 
 
 
 
 
 
 
9727e5e
0c90d5f
 
 
 
 
9727e5e
0c90d5f
 
9727e5e
0c90d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9727e5e
0c90d5f
 
 
 
 
 
 
 
9727e5e
 
 
0c90d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9727e5e
 
 
 
0c90d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
9727e5e
0c90d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9727e5e
 
0c90d5f
 
 
 
 
 
 
9727e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c90d5f
 
9727e5e
 
0c90d5f
 
 
 
 
 
9727e5e
0c90d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
9727e5e
0c90d5f
 
 
 
 
 
 
 
9727e5e
 
0c90d5f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
"""Module for loading and processing .mat files containing channel estimates for PyTorch.

This module expects .mat files with a specific naming convention and internal structure:

File Naming Convention:
    Files must follow the pattern: {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
    
    Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
    - file_number: Sequential file identifier
    - SNR: Signal-to-Noise Ratio in dB
    - DS: Delay Spread
    - DOP: Maximum Doppler Shift
    - N: Pilot placement frequency
    - channel_type: Channel model type (e.g., TDL-A)

File Content Structure:
    Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]:
    - H[:, :, 0]: Ground truth channel (complex values)
    - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
    - H[:, :, 2]: Present in the files but not read by this loader (its presence is only validated)

The dataset extracts pilot values from the LS estimates and provides metadata from the filename
for adaptive channel estimation models.
"""
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import scipy.io as sio
import torch
from torch.utils.data import Dataset, DataLoader
from pydantic import BaseModel, Field

from src.utils import extract_values

__all__ = ['MatDataset', 'get_test_dataloaders']


class PilotDimensions(BaseModel):
    """Validated dimensions of the pilot grid used for channel estimation.

    Both dimensions are validated by pydantic to be strictly positive
    integers at construction time.

    Attributes:
        num_subcarriers: Number of subcarriers carrying pilot symbols.
        num_ofdm_symbols: Number of OFDM symbols carrying pilot symbols.
    """
    num_subcarriers: int = Field(..., gt=0, description="Number of subcarriers in the pilot signal")
    num_ofdm_symbols: int = Field(..., gt=0, description="Number of OFDM symbols in the pilot signal")

    def as_tuple(self) -> Tuple[int, int]:
        """Return the pilot grid shape as ``(num_subcarriers, num_ofdm_symbols)``."""
        return (self.num_subcarriers, self.num_ofdm_symbols)


class MatDataset(Dataset):
    """Dataset for loading and formatting .mat files containing channel estimates.

    Processes .mat files containing channel estimation data and converts them into
    PyTorch complex tensors for channel estimation tasks.

    Expected File Format:
        - Files must be named according to the pattern:
          {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
        - Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]
        - H[:, :, 0]: Ground truth channel (complex values)
        - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
        - H[:, :, 2]: Present in the files but not read by this dataset
          (only its existence is validated; presumably an interpolated
          estimate — confirm against the data-generation pipeline)

    Returns:
        For each sample, returns a tuple of:
        - Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
        - Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
        - Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
    """

    def __init__(
            self,
            data_dir: Union[str, Path],
            pilot_dims: List[int],
            transform: Optional[Callable] = None
    ) -> None:
        """Initialize the MatDataset.

        Args:
            data_dir: Path to the directory containing the dataset (should contain .mat files).
            pilot_dims: Dimensions of pilot data as [num_subcarriers, num_ofdm_symbols].
            transform: Optional transformation applied to both the pilot estimate
                and the ground truth tensor of each sample.

        Raises:
            FileNotFoundError: If data_dir doesn't exist.
            ValueError: If no .mat files are found in data_dir.
        """
        self.data_dir = Path(data_dir)
        # PilotDimensions validates that both dimensions are positive.
        self.pilot_dims = PilotDimensions(num_subcarriers=pilot_dims[0], num_ofdm_symbols=pilot_dims[1])
        self.transform = transform

        if not self.data_dir.exists():
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")

        self.file_list = list(self.data_dir.glob("*.mat"))
        if not self.file_list:
            raise ValueError(f"No .mat files found in {self.data_dir}")

    def __len__(self) -> int:
        """Return the total number of files in the dataset.

        Returns:
            Integer count of .mat files in the dataset directory
        """
        return len(self.file_list)

    def _process_channel_data(
            self,
            mat_data: dict
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process channel data and extract pilot values from LS estimates.

        Extracts the non-zero pilot values from the LS channel estimate plane
        and reshapes them onto the configured pilot grid, returning
        complex-valued tensors for both the estimate and the ground truth.

        Args:
            mat_data: Loaded .mat file data containing 'H' variable

        Returns:
            Tuple of (pilot LS estimate, ground truth channel)

        Raises:
            ValueError: If the data format is unexpected or processing fails
        """
        try:
            # Extract ground truth channel
            h_ideal = torch.tensor(mat_data['H'][:, :, 0], dtype=torch.cfloat)

            # Extract LS channel estimate with zero entries
            hzero_ls = torch.tensor(mat_data['H'][:, :, 1], dtype=torch.cfloat)

            # Keep only pilot positions: boolean masking flattens the grid in
            # row-major (subcarrier-major) order. Comparing a complex tensor
            # to 0 is equivalent to comparing against complex zero.
            hp_ls = hzero_ls[hzero_ls != 0]

            # Validate expected number of pilot values
            expected_pilots = self.pilot_dims.num_subcarriers * self.pilot_dims.num_ofdm_symbols
            if hp_ls.numel() != expected_pilots:
                raise ValueError(
                    f"Expected {expected_pilots} pilot values, got {hp_ls.numel()}"
                )

            # Reshape to pilot grid dimensions [subcarriers, symbols].
            # (The flat vector is laid out symbol-last, hence the transpose.)
            hp_ls = hp_ls.view(
                self.pilot_dims.num_ofdm_symbols,
                self.pilot_dims.num_subcarriers
            ).t()

            return hp_ls, h_ideal

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(f"Error processing channel data: {str(e)}") from e

    def __getitem__(
            self,
            idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
        """Load and process a .mat file at the given index.

        Args:
            idx: Index of the file to load.

        Returns:
            Tuple containing:
                - Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
                - Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
                - Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
                  parsed from the filename by ``extract_values``

        Raises:
            ValueError: If file format is invalid or processing fails.
            IndexError: If idx is out of range.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")

        try:
            mat_data = sio.loadmat(self.file_list[idx])
            # Require 'H' with at least three planes along the last axis.
            if 'H' not in mat_data or mat_data['H'].shape[-1] < 3:
                raise ValueError("Invalid .mat file format: missing required data")

            # Process channel data to extract pilot estimates
            h_est, h_ideal = self._process_channel_data(mat_data)

            # Extract metadata from filename
            meta_data = extract_values(self.file_list[idx].name)
            if meta_data is None:
                raise ValueError(f"Unrecognized filename format: {self.file_list[idx].name}")

            # Apply optional transforms
            if self.transform:
                h_est = self.transform(h_est)
                h_ideal = self.transform(h_ideal)

            return h_est, h_ideal, meta_data

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(f"Error processing file {self.file_list[idx]}: {str(e)}") from e


def get_test_dataloaders(
        dataset_dir: Union[str, Path],
        pilot_dims: List[int],
        batch_size: int
) -> List[Tuple[str, DataLoader]]:
    """Create DataLoaders for each subdirectory in the dataset directory.

    Automatically discovers and creates appropriate DataLoader instances for
    all subdirectories in the specified dataset directory, useful for testing
    across multiple test conditions or scenarios. Subdirectories are processed
    in sorted name order so results are reproducible across filesystems.

    Expected Directory Structure:
        dataset_dir/
        β”œβ”€β”€ DS_50/          # Delay Spread = 50
        β”‚   β”œβ”€β”€ 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
        β”‚   β”œβ”€β”€ 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
        β”‚   └── ...
        β”œβ”€β”€ DS_100/         # Delay Spread = 100
        β”‚   β”œβ”€β”€ 1_SNR-20_DS-100_DOP-500_N-3_TDL-A.mat
        β”‚   └── ...
        β”œβ”€β”€ SNR_10/         # SNR = 10 dB
        β”‚   β”œβ”€β”€ 1_SNR-10_DS-50_DOP-500_N-3_TDL-A.mat
        β”‚   └── ...
        └── ...

    Each subdirectory should contain .mat files with the naming convention:
    {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat

    Args:
        dataset_dir: Path to main directory containing dataset subdirectories
        pilot_dims: List of [num_subcarriers, num_ofdm_symbols] for pilot dimensions
        batch_size: Number of samples per batch

    Returns:
        List of tuples containing (subdirectory_name, corresponding_dataloader),
        ordered by subdirectory name.

    Raises:
        FileNotFoundError: If dataset_dir doesn't exist
        ValueError: If no valid subdirectories are found
    """
    dataset_dir = Path(dataset_dir)
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # Sort by name: iterdir() order is filesystem-dependent, and deterministic
    # ordering makes test sweeps reproducible.
    subdirs = sorted((d for d in dataset_dir.iterdir() if d.is_dir()), key=lambda d: d.name)
    if not subdirs:
        raise ValueError(f"No subdirectories found in {dataset_dir}")

    return [
        (
            subdir.name,
            DataLoader(
                MatDataset(subdir, pilot_dims),
                batch_size=batch_size,
                shuffle=False,  # no shuffling for testing
                num_workers=0,
            ),
        )
        for subdir in subdirs
    ]