"""Module for loading and processing .mat files containing channel estimates for PyTorch.
This module expects .mat files with a specific naming convention and internal structure:
File Naming Convention:
Files must follow the pattern: {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
- file_number: Sequential file identifier
- SNR: Signal-to-Noise Ratio in dB
- DS: Delay Spread
- DOP: Maximum Doppler Shift
- N: Pilot placement frequency
- channel_type: Channel model type (e.g., TDL-A)
File Content Structure:
Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]:
- H[:, :, 0]: Ground truth channel (complex values)
- H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
    - H[:, :, 2]: Bilinear interpolated LS channel estimate (not used by this dataset)
The dataset extracts pilot values from the LS estimates and provides metadata from the filename
for adaptive channel estimation models.
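Example:
    A minimal usage sketch (the directory path and pilot grid size below are
    hypothetical and depend on how the data was generated):

        dataset = MatDataset("data/train", pilot_dims=[24, 2])
        h_pilot, h_true, meta = dataset[0]
        # h_pilot: complex pilot LS estimate, h_true: full ground-truth channel grid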
"""
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import scipy.io as sio
import torch
from torch.utils.data import Dataset, DataLoader
from pydantic import BaseModel, Field
from src.utils import extract_values
__all__ = ['MatDataset', 'get_test_dataloaders']
class PilotDimensions(BaseModel):
"""Container for pilot signal dimensions.
Stores and validates the dimensions of pilot signals used in channel estimation.
Attributes:
num_subcarriers: Number of subcarriers in the pilot signal
num_ofdm_symbols: Number of OFDM symbols in the pilot signal
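    Example:
        The dimensions below are arbitrary placeholders:

            dims = PilotDimensions(num_subcarriers=24, num_ofdm_symbols=2)
            dims.as_tuple()  # -> (24, 2)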
"""
num_subcarriers: int = Field(..., gt=0, description="Number of subcarriers in the pilot signal")
num_ofdm_symbols: int = Field(..., gt=0, description="Number of OFDM symbols in the pilot signal")
def as_tuple(self) -> Tuple[int, int]:
"""Return dimensions as a tuple.
Returns:
Tuple of (num_subcarriers, num_ofdm_symbols)
"""
return self.num_subcarriers, self.num_ofdm_symbols
class MatDataset(Dataset):
"""Dataset for loading and formatting .mat files containing channel estimates.
Processes .mat files containing channel estimation data and converts them into
PyTorch complex tensors for channel estimation tasks.
Expected File Format:
- Files must be named according to the pattern:
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
- Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]
- H[:, :, 0]: Ground truth channel (complex values)
- H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
- H[:, :, 2]: Bilinear interpolated LS channel estimate
Returns:
For each sample, returns a tuple of:
- Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
- Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
- Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
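    Example:
        A minimal sketch (the directory name, pilot grid size, and batch size
        are hypothetical):

            dataset = MatDataset("data/test/DS_50", pilot_dims=[24, 2])
            loader = DataLoader(dataset, batch_size=32, shuffle=True)
            for h_pilot, h_true, meta in loader:
                ...  # train or evaluate a channel estimation model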
"""
def __init__(
self,
data_dir: Union[str, Path],
pilot_dims: List[int],
transform: Optional[Callable] = None
) -> None:
"""Initialize the MatDataset.
Args:
data_dir: Path to the directory containing the dataset (should contain .mat files).
pilot_dims: Dimensions of pilot data as [num_subcarriers, num_ofdm_symbols].
transform: Optional transformation to apply to samples.
Raises:
FileNotFoundError: If data_dir doesn't exist.
ValueError: If no .mat files are found in data_dir.
"""
self.data_dir = Path(data_dir)
self.pilot_dims = PilotDimensions(num_subcarriers=pilot_dims[0], num_ofdm_symbols=pilot_dims[1])
self.transform = transform
if not self.data_dir.exists():
raise FileNotFoundError(f"Data directory not found: {self.data_dir}")
self.file_list = list(self.data_dir.glob("*.mat"))
if not self.file_list:
raise ValueError(f"No .mat files found in {self.data_dir}")
def __len__(self) -> int:
"""Return the total number of files in the dataset.
Returns:
Integer count of .mat files in the dataset directory
"""
return len(self.file_list)
def _process_channel_data(
self,
mat_data: dict
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Process channel data and extract pilot values from LS estimates.
Extracts pilot values from LS channel estimates with zero entries removed,
returning complex-valued tensors for both estimate and ground truth.
Args:
mat_data: Loaded .mat file data containing 'H' variable
Returns:
Tuple of (pilot LS estimate, ground truth channel)
Raises:
ValueError: If the data format is unexpected or processing fails
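        Example:
            A conceptual sketch of the non-zero masking applied below, on a toy grid:

                h = torch.tensor([[1 + 1j, 0], [0, 2 + 2j]], dtype=torch.cfloat)
                pilots = h[h != 0]  # keeps only the two non-zero pilot values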
"""
try:
# Extract ground truth channel
h_ideal = torch.tensor(mat_data['H'][:, :, 0], dtype=torch.cfloat)
# Extract LS channel estimate with zero entries
hzero_ls = torch.tensor(mat_data['H'][:, :, 1], dtype=torch.cfloat)
            # Keep only the non-zero entries of the LS estimate (pilot positions)
            hp_ls = hzero_ls[hzero_ls != 0]
# Validate expected number of pilot values
expected_pilots = self.pilot_dims.num_subcarriers * self.pilot_dims.num_ofdm_symbols
if hp_ls.numel() != expected_pilots:
raise ValueError(
f"Expected {expected_pilots} pilot values, got {hp_ls.numel()}"
)
            # Reshape to pilot grid dimensions [subcarriers, symbols]
            hp_ls = hp_ls.view(
                self.pilot_dims.num_ofdm_symbols,
                self.pilot_dims.num_subcarriers
            ).t()
return hp_ls, h_ideal
        except Exception as e:
            raise ValueError(f"Error processing channel data: {e}") from e
def __getitem__(
self,
idx: int
) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
"""Load and process a .mat file at the given index.
Args:
idx: Index of the file to load.
Returns:
Tuple containing:
- Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
- Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
- Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
            All metadata values are torch.Tensor except channel_type, which is a list.
Raises:
ValueError: If file format is invalid or processing fails.
IndexError: If idx is out of range.
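        Example:
            A sketch of unpacking one sample (field order follows the filename
            pattern parsed by extract_values):

                h_pilot, h_true, meta = dataset[0]
                file_number, snr, delay_spread, doppler, pilot_freq, channel_type = meta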
"""
if not 0 <= idx < len(self):
raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")
try:
mat_data = sio.loadmat(self.file_list[idx])
            if 'H' not in mat_data or mat_data['H'].ndim != 3 or mat_data['H'].shape[-1] < 3:
                raise ValueError(
                    "Invalid .mat file format: expected variable 'H' with shape "
                    "[subcarriers, symbols, 3]"
                )
# Process channel data to extract pilot estimates
h_est, h_ideal = self._process_channel_data(mat_data)
# Extract metadata from filename
meta_data = extract_values(self.file_list[idx].name)
if meta_data is None:
raise ValueError(f"Unrecognized filename format: {self.file_list[idx].name}")
# Apply optional transforms
if self.transform:
h_est = self.transform(h_est)
h_ideal = self.transform(h_ideal)
return h_est, h_ideal, meta_data
        except Exception as e:
            raise ValueError(f"Error processing file {self.file_list[idx]}: {e}") from e
def get_test_dataloaders(
dataset_dir: Union[str, Path],
pilot_dims: List[int],
batch_size: int
) -> List[Tuple[str, DataLoader]]:
"""Create DataLoaders for each subdirectory in the dataset directory.
Automatically discovers and creates appropriate DataLoader instances for
all subdirectories in the specified dataset directory, useful for testing
across multiple test conditions or scenarios.
Expected Directory Structure:
        dataset_dir/
        ├── DS_50/                  # Delay Spread = 50
        │   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
        │   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
        │   └── ...
        ├── DS_100/                 # Delay Spread = 100
        │   ├── 1_SNR-20_DS-100_DOP-500_N-3_TDL-A.mat
        │   └── ...
        ├── SNR_10/                 # SNR = 10 dB
        │   ├── 1_SNR-10_DS-50_DOP-500_N-3_TDL-A.mat
        │   └── ...
        └── ...
Each subdirectory should contain .mat files with the naming convention:
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
Args:
dataset_dir: Path to main directory containing dataset subdirectories
pilot_dims: List of [num_subcarriers, num_ofdm_symbols] for pilot dimensions
batch_size: Number of samples per batch
Returns:
List of tuples containing (subdirectory_name, corresponding_dataloader)
Raises:
FileNotFoundError: If dataset_dir doesn't exist
ValueError: If no valid subdirectories are found
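    Example:
        An illustrative sketch (the directory layout, pilot grid size, and batch
        size are hypothetical):

            loaders = get_test_dataloaders("data/test", pilot_dims=[24, 2], batch_size=32)
            for name, loader in loaders:
                for h_pilot, h_true, meta in loader:
                    ...  # evaluate a model on the scenario identified by `name`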
"""
dataset_dir = Path(dataset_dir)
if not dataset_dir.exists():
raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
subdirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
if not subdirs:
raise ValueError(f"No subdirectories found in {dataset_dir}")
    test_datasets = [
        (subdir.name, MatDataset(subdir, pilot_dims))
        for subdir in subdirs
    ]
return [
(name, DataLoader(
dataset,
batch_size=batch_size,
shuffle=False, # no shuffling for testing
num_workers=0
))
for name, dataset in test_datasets
    ]