| from __future__ import annotations |
|
|
| import dataclasses |
| import math |
| from pathlib import Path |
| from typing import Iterable, List, Optional, Sequence |
|
|
| import imageio.v2 as imageio |
| import matplotlib.animation as animation |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
|
|
|
|
@dataclasses.dataclass
class AngleDelayConfig:
    """Tunable settings for the angle-delay transform and its plots.

    Constructing the dataclass performs no checks; call :meth:`validate`
    to verify the scalar fields.
    """

    # (min, max) extent drawn along the angle axis of the angle-delay heat-map.
    angle_range: tuple[float, float] = (-math.pi / 2, math.pi / 2)
    # (min, max) extent drawn along the delay axis of the angle-delay heat-map.
    delay_range: tuple[float, float] = (0.0, 100.0)
    # Fraction (not a 0-100 percentage) of delay bins kept when truncating;
    # validate() requires it to lie in (0, 1].
    keep_percentage: float = 0.25
    # Default frame rate used when an animation is written or displayed.
    fps: int = 4
    # Default resolution used when figures/animations are written to disk.
    dpi: int = 120
    # NOTE(review): not referenced by the code visible in this file; only
    # checked for positivity here.
    num_bins: int = 6
    # NOTE(review): not referenced by the code visible in this file;
    # presumably the default folder for generated figures - confirm with callers.
    output_dir: Path = Path("figs")

    def validate(self) -> None:
        """Check the scalar settings.

        Raises:
            ValueError: if ``keep_percentage`` is outside ``(0, 1]`` or if any
                of ``fps``, ``dpi`` or ``num_bins`` is not strictly positive.
                The first failing field (in that order) is the one reported.
        """
        # Written as a negated chained comparison so that NaN is rejected too.
        fraction_ok = 0.0 < self.keep_percentage <= 1.0
        if not fraction_ok:
            raise ValueError("keep_percentage must be in (0, 1]")

        for field_name in ("fps", "dpi", "num_bins"):
            if getattr(self, field_name) <= 0:
                raise ValueError(f"{field_name} must be positive")
|
|
|
|
class AngleDelayProcessor:
    """Project complex channels into the angle-delay domain and visualise them.

    Tensors are indexed as ``(T, N, M)``: dimension 0 is the frame index used
    by the animations, dimension 1 is transformed into the angle axis and
    dimension 2 into the delay axis (see :meth:`forward`).  Every public
    method rejects non-complex tensors with ``TypeError``.
    """

    # Dark-theme palette shared by the styled plots.
    _FIG_FACE = "#0b0e11"
    _PANEL_FACE = "#111827"
    _LABEL_COLOR = "#cbd5f5"
    _TITLE_COLOR = "#f8fafc"
    _HEATMAP_SPINE = "#374151"
    _PANEL_SPINE = "#1f2937"

    def __init__(self, config: AngleDelayConfig | None = None) -> None:
        """Store ``config`` (or a default one) after validating it.

        Raises:
            ValueError: propagated from ``config.validate()``.
        """
        self.config = AngleDelayConfig() if config is None else config
        self.config.validate()

    # ------------------------------------------------------------------
    # Transforms
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` unchanged, raising ``TypeError`` if it is not complex."""
        if not torch.is_complex(tensor):
            raise TypeError("expected complex tensor")
        return tensor

    def forward(self, channel: torch.Tensor) -> torch.Tensor:
        """Map a channel tensor to the angle-delay domain.

        Applies an orthonormal FFT along dimension 1 followed by an
        orthonormal inverse FFT along dimension 2.  The output has the same
        shape as the input.
        """
        channel = self._ensure_complex(channel)
        angle_domain = torch.fft.fft(channel, dim=1, norm="ortho")
        return torch.fft.ifft(angle_domain, dim=2, norm="ortho")

    def inverse(self, angle_delay: torch.Tensor) -> torch.Tensor:
        """Undo :meth:`forward`: FFT along dimension 2, inverse FFT along dimension 1."""
        angle_delay = self._ensure_complex(angle_delay)
        subcarrier = torch.fft.fft(angle_delay, dim=2, norm="ortho")
        return torch.fft.ifft(subcarrier, dim=1, norm="ortho")

    # ------------------------------------------------------------------
    # Truncation and error metrics
    # ------------------------------------------------------------------
    def truncate_delay_bins(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep only the leading delay bins of an angle-delay tensor.

        The number of bins kept is ``round(M * config.keep_percentage)``,
        but never fewer than one.

        Returns:
            ``(truncated, padded)`` where ``truncated`` is the leading slice
            along the last dimension and ``padded`` has the input's shape
            with the discarded bins set to zero.

        Raises:
            ValueError: if ``tensor`` is not three-dimensional.
        """
        tensor = self._ensure_complex(tensor)
        if tensor.ndim != 3:
            raise ValueError("angle-delay tensor must have shape (T, N, M)")
        keep = max(1, int(round(tensor.size(-1) * self.config.keep_percentage)))
        truncated = tensor[..., :keep]
        padded = torch.zeros_like(tensor)
        padded[..., :keep] = truncated
        return truncated, padded

    @staticmethod
    def nmse(reference: torch.Tensor, reconstruction: torch.Tensor) -> float:
        """Return the normalised mean squared error in dB.

        Computed as ``10 * log10(mean(|ref - rec|^2) / mean(|ref|^2))``.  The
        reference power is clamped to at least ``1e-12`` to avoid division by
        zero; a perfect reconstruction yields ``-inf``.
        """
        reference = AngleDelayProcessor._ensure_complex(reference)
        reconstruction = AngleDelayProcessor._ensure_complex(reconstruction)
        mse = torch.mean(torch.abs(reference - reconstruction) ** 2)
        power = torch.mean(torch.abs(reference) ** 2).clamp_min(1e-12)
        return float(10.0 * torch.log10(mse / power))

    def reconstruction_nmse(self, channel: torch.Tensor) -> tuple[float, float]:
        """Return ``(nmse_full, nmse_truncated)`` in dB for ``channel``.

        ``nmse_full`` compares the channel with its forward/inverse round
        trip; ``nmse_truncated`` does the same after zeroing the delay bins
        dropped by :meth:`truncate_delay_bins`.
        """
        ad_full = self.forward(channel)
        nmse_full = self.nmse(channel, self.inverse(ad_full))
        _, padded = self.truncate_delay_bins(ad_full)
        nmse_trunc = self.nmse(channel, self.inverse(padded))
        return nmse_full, nmse_trunc

    # ------------------------------------------------------------------
    # Plot helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _padded_limits(values, pad_ratio: float = 0.05) -> tuple[float, float]:
        """Return ``(low, high)`` axis limits that enclose ``values`` with a margin.

        Non-finite entries are ignored.  If nothing finite remains the
        fallback ``(0.0, 1.0)`` is returned; constant data still gets a
        non-zero span so matplotlib never receives identical limits.
        """
        data = np.asarray(values, dtype=float).ravel()
        finite = data[np.isfinite(data)]
        if finite.size == 0:
            return 0.0, 1.0
        low, high = float(finite.min()), float(finite.max())
        span = high - low
        pad = span * pad_ratio if span > 0 else max(abs(high), 1.0) * pad_ratio
        return low - pad, high + pad

    def _dark_heatmap(
        self,
        frame: np.ndarray,
        *,
        origin: str,
        vmin: float,
        vmax: float,
        xlabel: str,
        ylabel: str,
        colorbar_label: str,
        extent=None,
    ):
        """Create a dark-themed single-axes heat-map and return ``(fig, ax, image)``."""
        fig, ax = plt.subplots(figsize=(8, 6))
        fig.patch.set_facecolor(self._FIG_FACE)
        ax.set_facecolor(self._FIG_FACE)
        ax.tick_params(colors=self._LABEL_COLOR)
        for spine in ax.spines.values():
            spine.set_color(self._HEATMAP_SPINE)
        image = ax.imshow(
            frame,
            cmap="magma",
            origin=origin,
            aspect="auto",
            extent=extent,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel(xlabel, color=self._LABEL_COLOR)
        ax.set_ylabel(ylabel, color=self._LABEL_COLOR)
        cbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.yaxis.set_tick_params(color=self._LABEL_COLOR)
        plt.setp(cbar.ax.get_yticklabels(), color=self._LABEL_COLOR)
        cbar.set_label(colorbar_label, color=self._LABEL_COLOR)
        return fig, ax, image

    def _style_series_axis(self, ax, *, title: str, ylabel: str) -> None:
        """Apply the dark panel styling and legend to an already-plotted time series."""
        ax.set_facecolor(self._PANEL_FACE)
        ax.set_title(title, color=self._TITLE_COLOR)
        ax.set_xlabel("time index", color=self._LABEL_COLOR)
        ax.set_ylabel(ylabel, color=self._LABEL_COLOR)
        ax.tick_params(colors=self._LABEL_COLOR)
        ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.4)
        for spine in ax.spines.values():
            spine.set_color(self._PANEL_SPINE)
        legend = ax.legend(loc="upper left", fontsize=9)
        legend.get_frame().set_facecolor(self._PANEL_FACE)
        legend.get_frame().set_alpha(0.6)
        for text in legend.get_texts():
            text.set_color(self._LABEL_COLOR)

    def _save_animation(
        self,
        fig: plt.Figure,
        animate_fn,
        output_path: Path,
        fps: Optional[int] = None,
        dpi: Optional[int] = None,
        frames: Optional[int] = None,
        show: bool = False,
    ) -> None:
        """Render ``animate_fn`` over ``fig`` and either display or save it.

        Args:
            fig: figure that ``animate_fn`` updates; it is always closed
                before this method returns, even on error.
            animate_fn: callable taking the frame index.
            output_path: destination file; only used when ``show`` is false.
            fps: frame rate, defaulting to ``config.fps``.
            dpi: resolution for saving, defaulting to ``config.dpi``.
            frames: number of frames.  Callers should always pass it: with
                ``None`` matplotlib iterates an unbounded counter, so
                ``animate_fn`` would be called with indices beyond the data.
            show: when true, embed the animation as JS/HTML via IPython
                instead of writing a file.
        """
        frame_rate = fps or self.config.fps
        html = None
        try:
            anim = animation.FuncAnimation(fig, animate_fn, frames=frames)
            if show:
                from IPython.display import HTML, display

                html = anim.to_jshtml(fps=frame_rate)
            else:
                output_path = Path(output_path)
                output_path.parent.mkdir(parents=True, exist_ok=True)
                anim.save(
                    output_path,
                    writer="pillow",
                    fps=frame_rate,
                    dpi=dpi or self.config.dpi,
                )
        finally:
            plt.close(fig)
        if show:
            display(HTML(html))

    # ------------------------------------------------------------------
    # Visualisations
    # ------------------------------------------------------------------
    def save_angle_delay_gif(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        fps: Optional[int] = None,
        show: bool = False,
    ) -> None:
        """Render the per-frame magnitude of ``tensor`` as an animated heat-map.

        With ``show=True`` the animation is displayed inline through
        :meth:`_save_animation` and nothing is written.  Otherwise each frame
        is rasterised from a single reused figure and the stack is written
        to ``output_path`` with ``imageio``.

        The colour scale is fixed to the global min/max of the magnitude, and
        the axes extent is taken from ``config.delay_range`` (x) and
        ``config.angle_range`` (y).
        """
        tensor = self._ensure_complex(tensor)
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        magnitude = tensor.abs().cpu()
        num_frames = magnitude.size(0)
        vmin, vmax = float(magnitude.min()), float(magnitude.max())

        fig, ax, im = self._dark_heatmap(
            magnitude[0].numpy(),
            origin="lower",
            vmin=vmin,
            vmax=vmax,
            xlabel="Delay bins",
            ylabel="Angle bins",
            # The plotted data is the plain magnitude, not a dB conversion.
            colorbar_label="|H| (linear)",
            extent=[*self.config.delay_range, *self.config.angle_range],
        )

        def animate(idx: int):
            im.set_array(magnitude[idx].numpy())
            ax.set_title(
                f"Angle-Delay Intensity — Frame {idx}",
                color=self._TITLE_COLOR,
                fontsize=12,
                fontweight="semibold",
            )
            return (im,)

        if show:
            self._save_animation(fig, animate, output_path, fps=fps, frames=num_frames, show=True)
            return

        frames: List[np.ndarray] = []
        try:
            for frame_idx in range(num_frames):
                animate(frame_idx)
                fig.canvas.draw()
                # Copy: the canvas buffer is overwritten by the next draw.
                frames.append(np.asarray(fig.canvas.buffer_rgba()).copy())
        finally:
            plt.close(fig)

        imageio.mimsave(output_path, frames, fps=fps or self.config.fps)

    def save_channel_animation(self, channel: torch.Tensor, output_path: Path, show: bool = False) -> None:
        """Animate the magnitude of ``channel`` frame by frame.

        Rows are labelled "Antenna" and columns "Subcarrier"; the colour
        scale is fixed to the global min/max.  The result is displayed inline
        when ``show`` is true, otherwise saved to ``output_path``.
        """
        channel = self._ensure_complex(channel)
        magnitude = channel.abs().cpu()
        vmin, vmax = float(magnitude.min()), float(magnitude.max())

        fig, ax_mag, mag_img = self._dark_heatmap(
            magnitude[0].numpy(),
            origin="upper",
            vmin=vmin,
            vmax=vmax,
            xlabel="Subcarrier",
            ylabel="Antenna",
            colorbar_label="|H| (linear)",
        )

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            ax_mag.set_title(
                f"Channel Magnitude — Frame {idx}",
                color=self._TITLE_COLOR,
                fontsize=12,
                fontweight="semibold",
            )
            return (mag_img,)

        self._save_animation(fig, animate, output_path, frames=channel.size(0), show=show)

    def save_angle_delay_animation(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        keep_percentage: Optional[float] = None,
        show: bool = False,
    ) -> None:
        """Animate magnitude/phase heat-maps with running per-frame averages.

        The figure is a 2x2 grid: magnitude and phase images on top, and the
        frame-averaged magnitude and (unwrapped) frame-averaged phase drawn
        progressively underneath.

        Args:
            tensor: complex angle-delay tensor indexed by frame on dim 0.
            output_path: destination file when ``show`` is false.
            keep_percentage: only used to annotate the titles with
                ``(keep=NN%)``; no truncation is performed here.
            show: display inline instead of saving.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        phase = torch.angle(tensor).cpu()
        num_frames = tensor.size(0)
        keep_suffix = "" if keep_percentage is None else f" (keep={keep_percentage * 100:.0f}%)"

        fig, axes = plt.subplots(2, 2, figsize=(18, 10))
        mag_ax, phase_ax, mag_line_ax, phase_line_ax = axes.flat

        # Fix the colour scale to the global range; otherwise it would be
        # derived from frame 0 alone and later frames would be clipped.
        mag_img = mag_ax.imshow(
            magnitude[0].numpy(),
            cmap="gray_r",
            origin="upper",
            aspect="auto",
            vmin=float(magnitude.min()),
            vmax=float(magnitude.max()),
        )
        mag_ax.set_xlabel("Delay Bin")
        mag_ax.set_ylabel("Angle Bin")
        fig.colorbar(mag_img, ax=mag_ax, label="Magnitude")

        phase_img = phase_ax.imshow(
            phase[0].numpy(),
            cmap="twilight",
            origin="upper",
            aspect="auto",
            vmin=-math.pi,
            vmax=math.pi,
        )
        phase_ax.set_xlabel("Delay Bin")
        phase_ax.set_ylabel("Angle Bin")
        fig.colorbar(phase_img, ax=phase_ax, label="Phase (rad)")

        temporal_mag = magnitude.mean(dim=(1, 2)).numpy()
        temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy())
        mag_line, = mag_line_ax.plot([], [], "r-o", linewidth=2)
        phase_line, = phase_line_ax.plot([], [], "b-s", linewidth=2)

        line_axes = (
            (mag_line_ax, "Average Magnitude", temporal_mag),
            (phase_line_ax, "Average Phase (rad)", temporal_phase),
        )
        for axis, label, series in line_axes:
            axis.set_xlabel("Frame")
            axis.set_ylabel(label)
            # max(1, ...) avoids identical limits for a single-frame tensor.
            axis.set_xlim(0, max(1, num_frames - 1))
            # set_data() does not autoscale, so the y-range must be set
            # explicitly or the curves fall outside the default view.
            axis.set_ylim(*self._padded_limits(series))
            axis.grid(True, alpha=0.3)

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            phase_img.set_array(phase[idx].numpy())
            mag_ax.set_title(f"AD Magnitude – Frame {idx}{keep_suffix}")
            phase_ax.set_title(f"AD Phase – Frame {idx}{keep_suffix}")
            xs = np.arange(idx + 1)
            mag_line.set_data(xs, temporal_mag[: idx + 1])
            phase_line.set_data(xs, temporal_phase[: idx + 1])
            return mag_img, phase_img, mag_line, phase_line

        # frames must be explicit: without it the animation would iterate an
        # unbounded counter and index past the end of the tensor.
        self._save_animation(fig, animate, output_path, frames=num_frames, show=show)

    def save_dominant_bin_animation(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        threshold_ratio: float = 0.05,
        show: bool = False,
    ) -> None:
        """Animate the magnitude next to the per-frame count of dominant bins.

        A bin is "dominant" in a frame when its magnitude exceeds
        ``threshold_ratio`` times the global maximum magnitude.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        num_frames = tensor.size(0)
        peak = float(magnitude.max())
        threshold = peak * threshold_ratio
        dominant_counts = (magnitude > threshold).sum(dim=(1, 2)).numpy()

        fig, (heat_ax, line_ax) = plt.subplots(1, 2, figsize=(16, 6))
        # Global colour scale, so frames after the first are not clipped.
        heat_img = heat_ax.imshow(
            magnitude[0].numpy(),
            cmap="gray_r",
            origin="upper",
            aspect="auto",
            vmin=float(magnitude.min()),
            vmax=peak,
        )
        heat_ax.set_xlabel("Delay Bin")
        heat_ax.set_ylabel("Angle Bin")
        fig.colorbar(heat_img, ax=heat_ax, label="Magnitude")

        count_line, = line_ax.plot([], [], "r-s", linewidth=2)
        line_ax.set_xlabel("Frame")
        line_ax.set_ylabel("Dominant Bin Count")
        line_ax.set_xlim(0, max(1, num_frames - 1))
        # Keep a non-degenerate range even when no bin exceeds the threshold.
        line_ax.set_ylim(0, max(1.0, float(dominant_counts.max()) * 1.1))
        line_ax.grid(True, alpha=0.3)

        def animate(idx: int):
            heat_img.set_array(magnitude[idx].numpy())
            heat_ax.set_title(f"Magnitude – Frame {idx}")
            xs = np.arange(idx + 1)
            count_line.set_data(xs, dominant_counts[: idx + 1])
            return heat_img, count_line

        self._save_animation(fig, animate, output_path, frames=num_frames, show=show)

    def save_bin_evolution_plot(self, tensor: torch.Tensor, output_path: Path, show: bool = False) -> None:
        """Plot magnitude and unwrapped phase over time for the strongest bins.

        The (up to three) bins with the largest time-averaged magnitude are
        selected; each gets one row with a magnitude panel and a phase panel.
        Returns without plotting when the tensor has no bins.  The figure is
        shown with ``plt.show()`` when ``show`` is true, otherwise written to
        ``output_path`` at ``config.dpi``; it is closed in either case.
        """
        tensor = self._ensure_complex(tensor)
        flat_mag = tensor.abs().mean(dim=0).flatten()

        k = min(3, flat_mag.numel())
        if k == 0:
            return
        _, indices = torch.topk(flat_mag, k)
        # Flat index -> (angle, delay) for a row-major (N, M) layout.
        num_delay = tensor.size(-1)
        bins = [divmod(flat_index, num_delay) for flat_index in indices.tolist()]

        time_axis = np.arange(tensor.size(0))
        fig, axes = plt.subplots(
            k,
            2,
            figsize=(11, 3 * k),
            dpi=150,
            constrained_layout=True,
        )
        try:
            fig.patch.set_facecolor(self._FIG_FACE)
            # subplots() returns a 1-D array when k == 1; normalise to rows.
            axes = np.atleast_2d(axes)

            for (ax_mag, ax_phase), (angle_idx, delay_idx) in zip(axes, bins):
                series = tensor[:, angle_idx, delay_idx]
                mag_series = torch.abs(series).cpu().numpy()
                phase_series = np.unwrap(torch.angle(series).cpu().numpy())
                bin_name = f"Bin (angle={angle_idx}, delay={delay_idx})"

                ax_mag.plot(
                    time_axis,
                    mag_series,
                    label="|H|",
                    color="#38bdf8",
                    linewidth=2.2,
                )
                ax_mag.fill_between(time_axis, mag_series, color="#38bdf8", alpha=0.08)
                self._style_series_axis(ax_mag, title=f"{bin_name} magnitude", ylabel="|H|")

                ax_phase.plot(
                    time_axis,
                    phase_series,
                    label="∠H",
                    color="#f87171",
                    linewidth=2.2,
                )
                self._style_series_axis(
                    ax_phase,
                    title=f"{bin_name} phase (unwrapped)",
                    ylabel="radians",
                )

            # Report the number of bins actually plotted (fewer than three
            # when the tensor has fewer bins).
            fig.suptitle(
                f"Top-{k} angle–delay bins over time",
                fontsize=12,
                color=self._TITLE_COLOR,
            )
            if show:
                plt.show()
            else:
                output_path = Path(output_path)
                output_path.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(output_path, dpi=self.config.dpi, bbox_inches="tight")
        finally:
            plt.close(fig)
|
|