# AdaFortiTran / src/utils.py
# Author: BerkIGuler — commit 4fa78a3 ("new README.md and Trainer design")
"""
Utility functions for OFDM channel estimation.
This module provides various utility functions for processing, visualizing,
and analyzing OFDM channel estimation data, including complex channel matrices,
error calculations, model statistics, and visualization tools for
performance evaluation.
"""
from pathlib import Path
from typing import Optional, Union
import re
import os
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from prettytable import PrettyTable
import torch
class EarlyStopping:
    """Handles early stopping logic for training.

    Monitors validation loss during training and signals when to stop
    training if the loss has not improved for a specified number of
    consecutive epochs.

    Attributes:
        patience: Number of consecutive non-improving epochs tolerated
            before stopping training.
        min_delta: Minimum decrease of the loss (relative to the best loss
            seen so far) that counts as an improvement. Defaults to 0.0,
            i.e. any strict decrease counts.
        remaining_patience: Non-improving epochs still tolerated before the
            stop signal is raised; reset to ``patience`` on improvement.
        min_loss: Minimum validation loss observed so far, or None before
            the first call to ``early_stop``.
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        """
        Initialize early stopping.

        Args:
            patience: Number of consecutive non-improving epochs to wait
                before stopping.
            min_delta: Minimum loss decrease required to count as an
                improvement (default 0.0).
        """
        self.patience = patience
        self.min_delta = min_delta
        self.remaining_patience = patience
        self.min_loss: Optional[float] = None

    def early_stop(self, loss: float) -> bool:
        """
        Check if training should stop.

        Args:
            loss: Current validation loss.

        Returns:
            True once the loss has failed to improve (by more than
            ``min_delta``) for ``patience`` consecutive calls; False otherwise.
        """
        if self.min_loss is None:
            # First observation only establishes the baseline.
            self.min_loss = loss
            return False

        if loss < self.min_loss - self.min_delta:
            self.min_loss = loss
            self.remaining_patience = self.patience
            return False

        self.remaining_patience -= 1
        # `<=` rather than `==`: with patience <= 0 the counter would skip
        # past zero and never trigger, and a caller that keeps calling after
        # the first stop signal must keep receiving True.
        return self.remaining_patience <= 0
def extract_values(file_name):
    """
    Extract channel information from a file name.

    Parses file names with format:
    '{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat'

    The SNR field may carry a leading minus sign for negative SNRs
    (e.g. 'SNR--5' parses as -5).

    Example:
        For filename "1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat":
        - file_number: 1
        - snr: 20 (Signal-to-Noise Ratio in dB)
        - delay_spread: 50 (Delay Spread)
        - doppler: 500 (Maximum Doppler Shift)
        - pilot_freq: 3 (Pilot placement frequency)
        - channel_type: TDL-A (Channel model type)

    Args:
        file_name: The file name (not a full path) from which to extract
            channel information. Matching is anchored at the start of the
            string.

    Returns:
        tuple: A tuple containing:
            - file_number (torch.Tensor): The file number, shape (1,), float
            - snr (torch.Tensor): Signal-to-noise ratio value in dB
            - delay_spread (torch.Tensor): Delay spread value
            - max_doppler_shift (torch.Tensor): Maximum Doppler shift value
            - pilot_placement_frequency (torch.Tensor): Pilot placement frequency
            - channel_type (list): The channel type (e.g., ['TDL-A'])

    Raises:
        ValueError: If the file name does not match the expected pattern.
    """
    pattern = r'(\d+)_SNR-(-?\d+)_DS-(\d+)_DOP-(\d+)_N-(\d+)_([A-Z\-]+)\.mat'
    match = re.match(pattern, file_name)
    if match is None:
        raise ValueError(f"Cannot extract file information from {file_name!r}.")

    # Groups 1-5 are the numeric fields; each becomes a 1-element float tensor.
    file_no, snr_value, ds_value, dop_value, n = (
        torch.tensor([int(match.group(idx))], dtype=torch.float)
        for idx in range(1, 6)
    )
    channel_type = [match.group(6)]
    return file_no, snr_value, ds_value, dop_value, n, channel_type
def get_error_images(variable, channel_data, show=False):
    """
    Create visualizations of channel estimation errors.

    Generates a figure with one error heatmap per entry of ``channel_data``,
    showing the element-wise absolute difference between the estimated and
    ideal channels. All heatmaps share a fixed color scale of [0, 1] and a
    single color bar placed to the right of the subplots.

    Args:
        variable: Name of the variable being visualized (e.g., 'SNR', 'DS');
            used in each subplot title.
        channel_data: Dictionary mapping parameter values to dictionaries
            containing 'estimated_channel' and 'ideal_channel' tensors. The
            absolute error of each pair is expected to be a 2-D matrix
            (rows x columns).
        show: Whether to display the figure immediately (default: False).

    Returns:
        matplotlib.figure.Figure: The generated figure with error heatmaps.

    Raises:
        ValueError: If ``channel_data`` is empty.
    """
    if not channel_data:
        raise ValueError("channel_data must contain at least one entry.")

    # squeeze=False always yields a 2-D array of Axes, so a single entry
    # does not degrade into a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(1, len(channel_data), figsize=(20, 6), squeeze=False)

    for ax, (key, channels) in zip(axes[0], channel_data.items()):
        error_matrix = torch.abs(channels['estimated_channel'] - channels['ideal_channel'])
        error_numpy = error_matrix.detach().cpu().numpy()
        n_rows, n_cols = error_numpy.shape[:2]

        # aspect = cols/rows renders each heatmap as a square regardless of
        # the grid size; vmin/vmax are shared so colors are comparable.
        image = ax.imshow(error_numpy, cmap='viridis', aspect=n_cols / n_rows, vmin=0, vmax=1)
        ax.set_title(f"{variable} = {key}")
        ax.set_xlabel(f'Columns ({n_cols})')
        ax.set_ylabel(f'Rows ({n_rows})')

    # Lay out the subplots before adding the manually-placed color bar axes;
    # tight_layout cannot handle axes created with add_axes.
    fig.tight_layout(rect=(0, 0, 0.9, 1))  # leave space for the color bar
    cbar_ax = fig.add_axes((0.92, 0.15, 0.02, 0.7))  # [left, bottom, width, height]
    fig.colorbar(image, cax=cbar_ax, label='Error Magnitude')

    if show:
        plt.show()
    return fig
def concat_complex_channel(channel_matrix, dim: int = 1):
    """
    Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.

    Transforms a complex tensor into a real-valued tensor by concatenating
    the real part followed by the imaginary part along ``dim``.

    Args:
        channel_matrix: Complex channel tensor. ``torch.imag`` raises for
            real-valued input, so a complex dtype is required.
        dim: Dimension along which the two parts are concatenated
            (default: 1, the previous hard-coded behavior).

    Returns:
        Real-valued tensor whose size along ``dim`` is twice the input's:
        the real parts come first, then the imaginary parts.
    """
    real_part = torch.real(channel_matrix)
    imag_part = torch.imag(channel_matrix)
    return torch.cat((real_part, imag_part), dim=dim)
def get_test_stats_plot(x_name, stats, methods, show=False, y_label="MSE Error (dB)"):
    """
    Plot test statistics for multiple methods as line graphs.

    Creates a line plot comparing performance metrics (e.g., MSE) across
    different conditions or parameters for multiple methods. Each method's
    points are sorted by x-value and drawn as a dashed line with a marker.

    Args:
        x_name: Label for the x-axis (e.g., 'SNR', 'DS', 'Epoch').
        stats: List of dictionaries where each dictionary maps x-values to
            performance metrics for a specific method.
        methods: List of method names corresponding to each entry in stats;
            used as legend labels.
        show: Whether to display the plot immediately (default: False).
        y_label: Label for the y-axis (default: "MSE Error (dB)").

    Returns:
        matplotlib.figure.Figure: The generated figure object.

    Raises:
        AssertionError: If stats and methods lists have different lengths.
    """
    import itertools

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # AssertionError is kept so existing callers that catch it still work.
    if len(stats) != len(methods):
        raise AssertionError("Provided stats and methods do not have the same length.")

    fig, ax = plt.subplots()

    # Marker order: the first six markers once, then "o" plus those six,
    # repeated indefinitely.
    first_markers = ["*", "x", "+", "D", "v", "^"]
    markers = itertools.chain(first_markers, itertools.cycle(["o"] + first_markers))

    for stat, marker in zip(stats, markers):
        points = sorted(stat.items(), key=lambda item: item[0])
        x_vals = [x for x, _ in points]
        y_vals = [y for _, y in points]
        ax.plot(x_vals, y_vals, f"{marker}--")

    ax.set_xlabel(x_name)
    ax.set_ylabel(y_label)
    ax.grid()
    ax.legend(methods)

    if show:
        plt.show()
    return fig
def to_db(val):
    """
    Express a linear quantity on the decibel scale.

    Computes ``10 * log10(val)`` element-wise via NumPy, so both scalars and
    arrays are accepted.

    Args:
        val: Positive scalar or array-like in linear units.

    Returns:
        The value(s) in dB, as returned by NumPy.
    """
    log_value = np.log10(val)
    return 10 * log_value
def mse(x, y):
    """
    Mean squared error between two (possibly complex) arrays, in dB.

    The error is the mean over all elements of ``|x - y|**2``, converted
    with ``to_db``.

    Args:
        x: First numpy array (real or complex).
        y: Second numpy array, broadcast-compatible with ``x``.

    Returns:
        The mean squared error expressed in decibels.
    """
    squared_error = np.square(np.abs(x - y))
    return to_db(np.mean(squared_error))
def get_ls_mse_per_folder(folders_dir: Union[Path, str]):
    """
    Calculate average MSE for LS estimates in each subfolder.

    For each subfolder in the specified directory, calculates the average
    mean squared error between least-squares channel estimates and ideal
    channel values across all .mat files in that subfolder.

    Args:
        folders_dir: Path to directory containing subfolders with .mat files.
            Each subfolder should be named 'prefix_val', where val (the text
            after the last underscore) is an integer. Non-directory entries
            in ``folders_dir`` are ignored.

    Returns:
        Dictionary mapping integer values from subfolder names (in ascending
        order) to the average per-file MSE in dB. A subfolder with no .mat
        files maps to NaN.

    Notes:
        - Each .mat file should contain a 3D matrix 'H' where:
            - H[:,:,0] is the ideal channel
            - H[:,:,2] is the LS channel estimate
        - Files without a .mat suffix inside a subfolder are skipped and not
          counted in the average.
        - NOTE(review): the average is taken over per-file values that are
          already in dB (a mean of logarithms), not over linear MSE; confirm
          this is the intended metric.

    Raises:
        ValueError: If a subfolder's name does not end in '_<int>'.
    """
    root = Path(folders_dir)

    # Collect (value, directory) pairs; stray top-level files would otherwise
    # break the name parsing.
    folders = []
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        _, _, val = entry.name.rpartition("_")
        folders.append((int(val), entry))
    folders.sort(key=lambda item: item[0])

    mse_sums = {}
    for val, folder in folders:
        # Sorted for a deterministic summation order.
        mat_files = sorted(
            p for p in folder.iterdir() if p.is_file() and p.suffix.lower() == ".mat"
        )
        if not mat_files:
            mse_sums[val] = float("nan")
            continue

        mse_sum = 0.0
        for mat_file in mat_files:
            mat_data = sio.loadmat(str(mat_file))['H']
            ls_estimate = mat_data[:, :, 2]
            ideal = mat_data[:, :, 0]
            mse_sum += mse(ls_estimate, ideal)
        mse_sums[val] = mse_sum / len(mat_files)
    return mse_sums
def get_model_details(model):
    """
    Summarize the trainable parameters of a PyTorch model.

    Iterates over ``model.named_parameters()`` and keeps only parameters
    with ``requires_grad`` set. Each kept parameter contributes one table
    row with its name and element count (``numel()``).

    Args:
        model: PyTorch model to analyze.

    Returns:
        tuple containing:
            - total_params: Sum of element counts of all trainable parameters
            - table: PrettyTable with one row per trainable named parameter
              (columns "Modules" and "Parameters")
    """
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for param_name, count in trainable:
        table.add_row([param_name, count])

    return sum(count for _, count in trainable), table