|
""" |
|
Utility functions for OFDM channel estimation. |
|
|
|
This module provides various utility functions for processing, visualizing, |
|
and analyzing OFDM channel estimation data, including complex channel matrices, |
|
error calculations, model statistics, and visualization tools for |
|
performance evaluation. |
|
""" |
|
|
|
from pathlib import Path |
|
from typing import Optional, Union |
|
import re |
|
import os |
|
|
|
import numpy as np |
|
import scipy.io as sio |
|
import matplotlib.pyplot as plt |
|
from prettytable import PrettyTable |
|
import torch |
|
|
|
|
|
class EarlyStopping:
    """Handles early stopping logic for training.

    Monitors validation loss during training and signals when to stop
    training if the loss has not improved for a specified number of epochs.

    Attributes:
        patience: Number of consecutive non-improving calls tolerated before
            stopping is signalled
        remaining_patience: Current remaining patience counter; reset to
            ``patience`` whenever a new minimum loss is observed
        min_loss: Minimum validation loss observed so far (None until the
            first call to ``early_stop``)
    """

    def __init__(self, patience: int = 3):
        """
        Initialize early stopping.

        Args:
            patience: Number of consecutive non-improving epochs to wait
                before stopping
        """
        self.patience = patience
        self.remaining_patience = patience
        self.min_loss: Optional[float] = None

    def early_stop(self, loss: float) -> bool:
        """
        Check if training should stop.

        The first call only records the loss. Afterwards, a strictly lower
        loss resets the patience counter; any other loss decrements it.

        Args:
            loss: Current validation loss

        Returns:
            True once ``patience`` consecutive calls have failed to improve on
            the best loss seen so far, and on every further non-improving call
            after that; False otherwise.
        """
        if self.min_loss is None:
            self.min_loss = loss
            return False

        if loss < self.min_loss:
            self.min_loss = loss
            self.remaining_patience = self.patience
            return False

        self.remaining_patience -= 1
        # `<=` rather than `==`: if the caller ignores the first stop signal
        # (or patience <= 0), the counter keeps going negative and the signal
        # must not silently flip back to False.
        return self.remaining_patience <= 0
|
|
|
|
|
def extract_values(file_name):
    """
    Extract channel information from a file name.

    Parses file names with format:
    '{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat'

    The SNR field may be negative, in which case the separator hyphen is
    followed by a minus sign (e.g. "SNR--5" means -5 dB).

    Example:
        For filename "1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat":
        - file_number: 1
        - snr: 20 (Signal-to-Noise Ratio in dB)
        - delay_spread: 50 (Delay Spread)
        - doppler: 500 (Maximum Doppler Shift)
        - pilot_freq: 3 (Pilot placement frequency)
        - channel_type: TDL-A (Channel model type)

    Args:
        file_name: The file name from which to extract channel information.
            Matching starts at the beginning of the string, so a bare file
            name (no directory prefix) is expected.

    Returns:
        tuple: A tuple containing:
            - file_number (torch.Tensor): The file number, shape (1,), float32
            - snr (torch.Tensor): Signal-to-noise ratio value in dB
            - delay_spread (torch.Tensor): Delay spread value
            - max_doppler_shift (torch.Tensor): Maximum Doppler shift value
            - pilot_placement_frequency (torch.Tensor): Pilot placement frequency
            - channel_type (list): The channel type (e.g., ['TDL-A'])

    Raises:
        ValueError: If the file name does not match the expected pattern
    """
    # `-?` on the SNR group accepts negative SNRs; names that matched before
    # still produce exactly the same groups.
    pattern = r'(\d+)_SNR-(-?\d+)_DS-(\d+)_DOP-(\d+)_N-(\d+)_([A-Z\-]+)\.mat'
    match = re.match(pattern, file_name)
    if match is None:
        raise ValueError(f"Cannot extract file information from {file_name!r}.")

    # The first five groups are numeric; each becomes a 1-element float tensor.
    file_no, snr_value, ds_value, dop_value, n = (
        torch.tensor([int(group)], dtype=torch.float) for group in match.groups()[:5]
    )
    channel_type = [match.group(6)]
    return file_no, snr_value, ds_value, dop_value, n, channel_type
|
|
|
|
|
def get_error_images(variable, channel_data, show=False):
    """
    Create visualizations of channel estimation errors.

    Generates a figure with one error heatmap per entry of ``channel_data``,
    showing the element-wise absolute difference between the estimated and
    ideal channels. All heatmaps share a fixed colour range of [0, 1] and a
    single colour bar.

    Args:
        variable: Name of the variable being visualized (e.g., 'SNR', 'DS')
        channel_data: Dictionary mapping parameter values to dictionaries
            containing 'estimated_channel' and 'ideal_channel' tensors.
            The axis labels and aspect ratio presume 120x14 (rows x columns)
            matrices; NOTE(review): confirm against callers.
        show: Whether to display the figure immediately (default: False)

    Returns:
        matplotlib.figure.Figure: The generated figure with error heatmaps

    Raises:
        ValueError: If ``channel_data`` is empty
    """
    if not channel_data:
        raise ValueError("channel_data must contain at least one entry.")

    # squeeze=False keeps `axes` a 2-D array even for a single panel; without
    # it plt.subplots(1, 1) returns a bare Axes and indexing would fail.
    fig, axes = plt.subplots(1, len(channel_data), figsize=(20, 6), squeeze=False)

    for ax, (key, channels) in zip(axes[0], channel_data.items()):
        error_matrix = torch.abs(channels['estimated_channel'] - channels['ideal_channel'])
        error_numpy = error_matrix.detach().cpu().numpy()

        image = ax.imshow(error_numpy, cmap='viridis', aspect=14 / 120, vmin=0, vmax=1)
        ax.set_title(f"{variable} = {key}")
        ax.set_xlabel('Columns (14)')
        ax.set_ylabel('Rows (120)')

    # Every panel uses the same vmin/vmax, so a single colour bar built from
    # the last image is valid for all of them.
    cbar_ax = fig.add_axes((0.92, 0.15, 0.02, 0.7))
    fig.colorbar(image, cax=cbar_ax, label='Error Magnitude')

    # Leave the right 10% of the figure free for the colour bar.
    fig.tight_layout(rect=(0, 0, 0.9, 1))

    if show:
        plt.show()

    return fig
|
|
|
|
|
def concat_complex_channel(channel_matrix, dim: int = 1):
    """
    Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.

    Transforms a complex tensor into a real-valued tensor by concatenating
    the real and imaginary components along ``dim``.

    Args:
        channel_matrix: Complex channel tensor
        dim: Dimension along which the real and imaginary parts are
            concatenated (default: 1, the dimension previously hard-coded)

    Returns:
        Real-valued tensor whose size along ``dim`` is doubled: the real part
        comes first, followed by the imaginary part
    """
    real_part = torch.real(channel_matrix)
    imag_part = torch.imag(channel_matrix)
    return torch.cat((real_part, imag_part), dim=dim)
|
|
|
|
|
|
|
|
|
|
|
def get_test_stats_plot(x_name, stats, methods, show=False, y_label="MSE Error (dB)"):
    """
    Plot test statistics for multiple methods as line graphs.

    Creates a line plot comparing performance metrics (e.g., MSE) across
    different conditions or parameters for multiple methods. Each method is
    drawn as a dashed line with its own marker; points are ordered by x-value.

    Args:
        x_name: Label for the x-axis (e.g., 'SNR', 'DS', 'Epoch')
        stats: List of dictionaries where each dictionary maps x-values to
            performance metrics for a specific method
        methods: List of method names corresponding to each entry in stats
        show: Whether to display the plot immediately (default: False)
        y_label: Label for the y-axis (default: "MSE Error (dB)")

    Returns:
        matplotlib.figure.Figure: The generated figure object

    Raises:
        AssertionError: If stats and methods lists have different lengths
    """
    import itertools

    # Explicit check rather than `assert`, which is stripped under `python -O`;
    # AssertionError is kept because it is the documented exception type.
    if len(stats) != len(methods):
        raise AssertionError("Provided stats and methods do not have the same length.")

    # Markers: the first six once, then cycle through a seven-marker list
    # (starting with "o") for any additional methods.
    symbols = itertools.chain(
        ["*", "x", "+", "D", "v", "^"],
        itertools.cycle(["o", "*", "x", "+", "D", "v", "^"]),
    )

    fig, ax = plt.subplots()
    for stat, symbol in zip(stats, symbols):
        kv_pairs = sorted(stat.items(), key=lambda kv: kv[0])
        x_vals = [key for key, _ in kv_pairs]
        y_vals = [value for _, value in kv_pairs]
        ax.plot(x_vals, y_vals, f"{symbol}--")

    ax.set_xlabel(x_name)
    ax.set_ylabel(y_label)
    ax.grid()
    # Labels are matched to lines positionally, in the order they were plotted.
    ax.legend(methods)

    if show:
        plt.show()
    return fig
|
|
|
|
|
def to_db(val):
    """
    Express a power-like quantity on the decibel scale.

    Computes ``10 * log10(val)`` element-wise, so both scalars and numpy
    arrays are accepted.

    Args:
        val: Value or array to convert; must be positive for a finite result

    Returns:
        The input expressed in decibels (dB)
    """
    log_value = np.log10(val)
    return 10 * log_value
|
|
|
|
|
def mse(x, y):
    """
    Mean squared error between two complex arrays, expressed in dB.

    The squared magnitude of the element-wise difference is averaged over
    all elements, and the result is converted with ``to_db``.

    Args:
        x: First complex numpy array
        y: Second complex numpy array (same shape as x)

    Returns:
        MSE between ``x`` and ``y`` in decibels (dB)
    """
    difference = x - y
    power = np.square(np.abs(difference))
    return to_db(np.mean(power))
|
|
|
|
|
def get_ls_mse_per_folder(folders_dir: Union[Path, str]):
    """
    Calculate average MSE for LS estimates in each subfolder.

    For each subfolder in the specified directory, computes the mean squared
    error (in dB, via ``mse``) between the least-squares channel estimate and
    the ideal channel for every .mat file in that subfolder, and averages
    those per-file dB values arithmetically.

    Args:
        folders_dir: Path to directory containing subfolders with .mat files.
            Each subfolder should be named 'prefix_val' where val is an
            integer (the text after the last underscore).

    Returns:
        Dictionary mapping integer values from subfolder names to average MSE
        values in dB, in ascending order of the integer. Subfolders that
        contain no .mat files are omitted.

    Notes:
        - Each .mat file should contain a 3D matrix 'H' where:
            - H[:,:,0] is the ideal channel
            - H[:,:,2] is the LS channel estimate
        - Plain files directly inside ``folders_dir`` (e.g. .DS_Store) are
          ignored, as are non-.mat files inside the subfolders.
    """
    root = Path(folders_dir)

    def _folder_value(folder: Path) -> int:
        # Integer suffix after the last underscore of the folder name.
        return int(folder.name.rsplit("_", 1)[1])

    folders = sorted((entry for entry in root.iterdir() if entry.is_dir()), key=_folder_value)

    mse_sums = {}
    for folder in folders:
        mat_files = [
            entry for entry in folder.iterdir()
            if entry.is_file() and entry.suffix.lower() == ".mat"
        ]
        if not mat_files:
            # Avoid a division by zero; an empty folder has no average.
            continue

        mse_sum = 0.0
        for mat_file in mat_files:
            mat_data = sio.loadmat(str(mat_file))['H']
            ls_estimate = mat_data[:, :, 2]
            ideal = mat_data[:, :, 0]
            mse_sum += mse(ls_estimate, ideal)

        mse_sums[_folder_value(folder)] = mse_sum / len(mat_files)

    return mse_sums
|
|
|
def get_model_details(model):
    """
    Summarize the trainable parameters of a PyTorch model.

    Only parameters with ``requires_grad`` set are counted. Each one is
    listed by name alongside its element count.

    Args:
        model: PyTorch model to analyze

    Returns:
        tuple containing:
            - total_params: Sum of element counts over all trainable parameters
            - table: PrettyTable with one row per trainable parameter
    """
    table = PrettyTable(["Modules", "Parameters"])
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for param_name, count in trainable:
        table.add_row([param_name, count])
    total_params = sum(count for _, count in trainable)
    return total_params, table
|
|