"""
Utility functions for OFDM channel estimation.
This module provides various utility functions for processing, visualizing,
and analyzing OFDM channel estimation data, including complex channel matrices,
error calculations, model statistics, and visualization tools for
performance evaluation.
"""
from pathlib import Path
from typing import Optional, Union
import re
import os
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from prettytable import PrettyTable
import torch


class EarlyStopping:
    """Handles early stopping logic for training.

    Monitors validation loss during training and signals when to stop
    training if the loss has not improved for a specified number of epochs.

    Attributes:
        patience: Number of epochs to wait before stopping training
        remaining_patience: Current remaining patience counter
        min_loss: Minimum validation loss observed so far
    """

    def __init__(self, patience: int = 3):
        """
        Initialize early stopping.

        Args:
            patience: Number of epochs to wait before stopping
        """
        self.patience = patience
        self.remaining_patience = patience
        self.min_loss: Optional[float] = None

    def early_stop(self, loss: float) -> bool:
        """
        Check if training should stop.

        Args:
            loss: Current validation loss

        Returns:
            Whether to stop training
        """
        if self.min_loss is None:
            self.min_loss = loss
            return False
        if loss < self.min_loss:
            self.min_loss = loss
            self.remaining_patience = self.patience
            return False
        self.remaining_patience -= 1
        return self.remaining_patience == 0
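
# Illustrative sketch of how EarlyStopping might be wired into a training loop.
# `train_one_epoch` and `evaluate` are hypothetical callables supplied by the
# caller; they are not defined in this module.
def _example_early_stopping(train_one_epoch, evaluate, max_epochs: int = 100) -> None:
    """Run a training loop that stops once validation loss stalls (example only)."""
    stopper = EarlyStopping(patience=3)
    for epoch in range(max_epochs):
        train_one_epoch()
        val_loss = evaluate()
        if stopper.early_stop(val_loss):
            print(f"Stopping early at epoch {epoch} (best loss: {stopper.min_loss})")
            break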


def extract_values(file_name):
    """
    Extract channel information from a file name.

    Parses file names with format:
    '{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat'

    Example:
        For filename "1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat":
        - file_number: 1
        - snr: 20 (Signal-to-Noise Ratio in dB)
        - delay_spread: 50 (Delay Spread)
        - doppler: 500 (Maximum Doppler Shift)
        - pilot_freq: 3 (Pilot placement frequency)
        - channel_type: TDL-A (Channel model type)

    Args:
        file_name: The file name from which to extract channel information

    Returns:
        tuple: A tuple containing:
            - file_number (torch.Tensor): The file number (sequential identifier)
            - snr (torch.Tensor): Signal-to-noise ratio value in dB
            - delay_spread (torch.Tensor): Delay spread value
            - max_doppler_shift (torch.Tensor): Maximum Doppler shift value
            - pilot_placement_frequency (torch.Tensor): Pilot placement frequency
            - channel_type (list): The channel type (e.g., ['TDL-A'])

    Raises:
        ValueError: If the file name does not match the expected pattern
    """
    pattern = r'(\d+)_SNR-(\d+)_DS-(\d+)_DOP-(\d+)_N-(\d+)_([A-Z\-]+)\.mat'
    match = re.match(pattern, file_name)
    if match:
        file_no = torch.tensor([int(match.group(1))], dtype=torch.float)
        snr_value = torch.tensor([int(match.group(2))], dtype=torch.float)
        ds_value = torch.tensor([int(match.group(3))], dtype=torch.float)
        dop_value = torch.tensor([int(match.group(4))], dtype=torch.float)
        n = torch.tensor([int(match.group(5))], dtype=torch.float)
        channel_type = [match.group(6)]
        return file_no, snr_value, ds_value, dop_value, n, channel_type
    else:
        raise ValueError("Cannot extract file information.")
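
# Illustrative sketch: parsing the file name used as the example in the
# docstring above. Each numeric field comes back as a 1-element float tensor.
def _example_extract_values() -> None:
    """Parse a sample file name (example only)."""
    file_no, snr, ds, dop, n, channel_type = extract_values(
        "1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat"
    )
    print(file_no, snr, ds, dop, n, channel_type)
    # e.g. tensor([1.]) tensor([20.]) tensor([50.]) tensor([500.]) tensor([3.]) ['TDL-A']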


def get_error_images(variable, channel_data, show=False):
    """
    Create visualizations of channel estimation errors.

    Generates a figure with error heatmaps for different channel conditions,
    showing the absolute difference between estimated and ideal channels.

    Args:
        variable: Name of the variable being visualized (e.g., 'SNR', 'DS')
        channel_data: Dictionary mapping parameter values to dictionaries
            containing 'estimated_channel' and 'ideal_channel'
        show: Whether to display the figure immediately (default: False)

    Returns:
        matplotlib.figure.Figure: The generated figure with error heatmaps
    """
    # Create one subplot per entry in channel_data, laid out in a single row
    fig, axes = plt.subplots(1, len(channel_data), figsize=(20, 6))
    axes = np.atleast_1d(axes)  # keep indexing uniform when there is only one subplot
    # Plot each subplot with consistent color scaling
    for i, (key, channels) in enumerate(channel_data.items()):
        # Calculate absolute error between estimated and ideal channels
        estimated_channel = channels['estimated_channel']
        ideal_channel = channels['ideal_channel']
        error_matrix = torch.abs(estimated_channel - ideal_channel)
        error_numpy = error_matrix.detach().cpu().numpy()
        # Plot in the corresponding subplot with shared colormap limits
        ax = axes[i]
        cax = ax.imshow(error_numpy, cmap='viridis', aspect=14 / 120, vmin=0, vmax=1)
        ax.set_title(f"{variable} = {key}")
        ax.set_xlabel('Columns (14)')
        ax.set_ylabel('Rows (120)')
    # Create a new axis for the color bar to the right of the subplots
    cbar_ax = fig.add_axes((0.92, 0.15, 0.02, 0.7))  # [left, bottom, width, height]
    fig.colorbar(cax, cax=cbar_ax, label='Error Magnitude')
    # Adjust layout to prevent overlapping labels
    fig.tight_layout(rect=(0, 0, 0.9, 1))  # Leave space for the color bar on the right
    # Show the figure if `show` is True
    if show:
        plt.show()
    # Return the main figure
    return fig
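
# Illustrative sketch of calling get_error_images with random 120 x 14 tensors
# (the shape implied by the axis labels above); real estimated/ideal channels
# would come from the dataset and the model. The output path is hypothetical.
def _example_get_error_images() -> None:
    """Render error heatmaps for three toy SNR conditions (example only)."""
    channel_data = {
        snr: {
            "estimated_channel": torch.rand(120, 14),
            "ideal_channel": torch.rand(120, 14),
        }
        for snr in (0, 10, 20)
    }
    fig = get_error_images("SNR", channel_data, show=False)
    fig.savefig("error_heatmaps.png")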


def concat_complex_channel(channel_matrix):
    """
    Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.

    Transforms a complex tensor into a real-valued tensor by concatenating
    the real and imaginary components along dimension 1.

    Args:
        channel_matrix: Complex channel matrix

    Returns:
        Real-valued channel matrix with the real and imaginary parts concatenated along dimension 1
    """
    real_channel_m = torch.real(channel_matrix)
    imag_channel_m = torch.imag(channel_matrix)
    cat_channel_m = torch.cat((real_channel_m, imag_channel_m), dim=1)
    return cat_channel_m
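
# Illustrative sketch: a complex (120, 14) matrix becomes a real (120, 28) matrix,
# with the real part in the first 14 columns and the imaginary part in the last 14.
def _example_concat_complex_channel() -> None:
    """Show the shape change produced by concat_complex_channel (example only)."""
    h = torch.complex(torch.rand(120, 14), torch.rand(120, 14))
    h_real = concat_complex_channel(h)
    print(h.shape, "->", h_real.shape)  # torch.Size([120, 14]) -> torch.Size([120, 28])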


def get_test_stats_plot(x_name, stats, methods, show=False):
    """
    Plot test statistics for multiple methods as line graphs.

    Creates a line plot comparing performance metrics (e.g., MSE) across
    different conditions or parameters for multiple methods.

    Args:
        x_name: Label for the x-axis (e.g., 'SNR', 'DS', 'Epoch')
        stats: List of dictionaries where each dictionary maps x-values to
            performance metrics for a specific method
        methods: List of method names corresponding to each entry in stats
        show: Whether to display the plot immediately (default: False)

    Returns:
        matplotlib.figure.Figure: The generated figure object

    Raises:
        AssertionError: If stats and methods lists have different lengths
    """
    assert len(stats) == len(methods), "Provided stats and methods do not have the same length."
    fig = plt.figure()
    symbols = iter(["*", "x", "+", "D", "v", "^"])
    for stat in stats:
        try:
            symbol = next(symbols)
        except StopIteration:
            # Restart the marker cycle when there are more methods than markers
            symbols = iter(["o", "*", "x", "+", "D", "v", "^"])
            symbol = next(symbols)
        # Sort by x-value so each line is drawn left to right
        kv_pairs = sorted(stat.items(), key=lambda kv: kv[0])
        x_vals = [key for key, _ in kv_pairs]
        y_vals = [value for _, value in kv_pairs]
        plt.plot(x_vals, y_vals, f"{symbol}--")
    plt.xlabel(x_name)
    plt.ylabel("MSE (dB)")
    plt.grid()
    plt.legend(methods)
    if show:
        plt.show()
    return fig
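
# Illustrative sketch comparing two methods' MSE-vs-SNR curves; the numbers
# below are made up for demonstration and do not come from any experiment,
# and the output path is hypothetical.
def _example_get_test_stats_plot() -> None:
    """Plot MSE-vs-SNR curves for two hypothetical methods (example only)."""
    model_stats = {0: -10.0, 10: -18.0, 20: -25.0}
    ls_stats = {0: -5.0, 10: -12.0, 20: -17.0}
    fig = get_test_stats_plot("SNR (dB)", [model_stats, ls_stats], ["Model", "LS"], show=False)
    fig.savefig("mse_vs_snr.png")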


def to_db(val):
    """
    Convert values to decibels (dB).

    Applies the formula 10 * log10(val) to convert values to the decibel scale.

    Args:
        val: Input value or array to convert to dB (must be positive)

    Returns:
        The input value(s) converted to decibels
    """
    return 10 * np.log10(val)


def mse(x, y):
    """
    Calculate mean squared error (MSE) in dB between two complex arrays.

    Computes the average squared magnitude of the difference between
    two complex arrays and converts the result to decibels.

    Args:
        x: First complex numpy array
        y: Second complex numpy array (same shape as x)

    Returns:
        MSE in decibels (dB) between the two arrays
    """
    mse_xy = np.mean(np.square(np.abs(x - y)))
    mse_xy_db = to_db(mse_xy)
    return mse_xy_db
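
# Illustrative sketch: the dB-scale MSE between a complex channel and a noisy
# copy of it. With noise of standard deviation 0.1 per component, the expected
# MSE is about 0.02, i.e. roughly 10 * log10(0.02) ≈ -17 dB.
def _example_mse() -> None:
    """Compute the dB-scale MSE between two small complex arrays (example only)."""
    rng = np.random.default_rng(0)
    h_ideal = rng.standard_normal((120, 14)) + 1j * rng.standard_normal((120, 14))
    noise = 0.1 * (rng.standard_normal((120, 14)) + 1j * rng.standard_normal((120, 14)))
    print(f"MSE: {mse(h_ideal + noise, h_ideal):.2f} dB")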


def get_ls_mse_per_folder(folders_dir: Union[Path, str]):
    """
    Calculate average MSE for LS estimates in each subfolder.

    For each subfolder in the specified directory, computes the MSE (in dB)
    between the least-squares channel estimate and the ideal channel for every
    .mat file, then averages those per-file dB values across the subfolder.

    Args:
        folders_dir: Path to directory containing subfolders with .mat files.
            Each subfolder should be named 'prefix_val' where val is an integer

    Returns:
        Dictionary mapping the integer value from each subfolder name to that
        subfolder's average MSE in dB

    Notes:
        - Each .mat file should contain a 3D matrix 'H' where:
            - H[:,:,0] is the ideal channel
            - H[:,:,2] is the LS channel estimate
        - Subfolders are sorted by the integer in their names
    """
    mse_sums = {}
    folders = os.listdir(folders_dir)
    folders = sorted(folders, key=lambda x: int(x.split("_")[1]))
    for folder in folders:
        _, val = folder.split("_")
        mse_sum = 0
        folder_size = len(os.listdir(os.path.join(folders_dir, folder)))
        for file in os.listdir(os.path.join(folders_dir, folder)):
            mat_data = sio.loadmat(os.path.join(folders_dir, folder, file))['H']
            ls_estimate = mat_data[:, :, 2]
            ideal = mat_data[:, :, 0]
            mse_sum += mse(ls_estimate, ideal)
        # Average the per-file dB values over the folder
        mse_sum /= folder_size
        mse_sums[int(val)] = mse_sum
    return mse_sums
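
# Illustrative sketch of aggregating LS MSE per SNR folder and plotting it.
# The directory path and the SNR_* subfolder names below are hypothetical; only
# the 'prefix_val' naming convention described above is required.
def _example_get_ls_mse_per_folder() -> None:
    """Plot LS MSE per SNR folder (example only)."""
    mse_per_snr = get_ls_mse_per_folder("data/test_snr")  # e.g. SNR_0/, SNR_10/, SNR_20/
    get_test_stats_plot("SNR (dB)", [mse_per_snr], ["LS"], show=True)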


def get_model_details(model):
    """
    Get parameter counts and structure details for a PyTorch model.

    Analyzes a PyTorch model to determine the total number of trainable
    parameters and creates a formatted table showing the parameter count
    for each named parameter in the model.

    Args:
        model: PyTorch model to analyze

    Returns:
        tuple containing:
            - total_params: Total number of trainable parameters
            - table: PrettyTable showing parameter counts by module
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    return total_params, table
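
# Illustrative sketch: summarising a small throwaway model. Any torch.nn.Module
# works; the two-layer network below exists purely for demonstration.
def _example_get_model_details() -> None:
    """Print the parameter table of a tiny model (example only)."""
    model = torch.nn.Sequential(
        torch.nn.Linear(28, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 28),
    )
    total_params, table = get_model_details(model)
    print(table)
    print(f"Total trainable parameters: {total_params}")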