| from __future__ import annotations |
|
|
| import dataclasses |
| import math |
| import pickle |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
| from .angle_delay import AngleDelayConfig, AngleDelayProcessor |
| from ..models.lwm import ComplexPatchTokenizer |
|
|
|
|
@dataclasses.dataclass
class AngleDelayDatasetConfig:
    """Options controlling how AngleDelaySequenceDataset builds, caches, and augments samples."""

    # Pickle file holding the raw channel data (a bare array or {"channel": ...}).
    raw_path: Path
    # Fraction of delay bins kept when truncating the angle-delay representation.
    keep_percentage: float = 0.25
    # "per_sample_rms", "global_rms", or any other value for no normalization.
    normalize: str = "global_rms"
    # Directory for processed-sample caches; None disables caching entirely.
    cache_dir: Optional[Path] = Path("cache")
    # Load/store processed samples via the cache when cache_dir is set.
    use_cache: bool = True
    # Rebuild samples even if a cache file already exists.
    overwrite_cache: bool = False
    # If set, complex AWGN at this SNR (dB) is added after building/loading.
    snr_db: Optional[float] = None
    # Optional seed applied before noise sampling; None leaves the RNG unseeded.
    noise_seed: Optional[int] = None
    # Truncate each sequence to at most this many time steps (None keeps all).
    max_time_steps: Optional[int] = None
    # (patch_height, patch_width) handed to the tokenizer in __getitem__.
    patch_size: Tuple[int, int] = (1, 1)
    # Forwarded to ComplexPatchTokenizer; presumably selects the token encoding — verify there.
    phase_mode: str = "real_imag"
|
|
|
|
class AngleDelaySequenceDataset(Dataset):
    """Angle-delay dataset that tokenizes sequences and caches the processed tensors.

    Samples are built once from the raw pickle (or loaded from a ``.pt`` cache
    keyed by the preprocessing options), then optional complex AWGN is added.
    ``__getitem__`` tokenizes the requested sample on every call.
    """

    def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None:
        """Build (or load from cache) the processed samples.

        Args:
            config: Dataset options; see :class:`AngleDelayDatasetConfig`.
            logger: Optional logger-like object; stored but not used here.
        """
        super().__init__()
        self.config = config
        self.logger = logger
        self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
        self.samples: List[torch.Tensor]
        cache_hit = False
        cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None
        if cache_path and cache_path.exists() and not config.overwrite_cache:
            try:
                # Caches are produced locally by this class, so unpickling is trusted.
                payload = torch.load(cache_path, map_location="cpu")
                if isinstance(payload, dict) and "samples" in payload:
                    self.samples = payload["samples"]
                else:
                    # Legacy caches stored the bare sample list.
                    self.samples = payload
                cache_hit = True
            except Exception:
                # Corrupt or incompatible cache: drop it and rebuild below.
                cache_path.unlink(missing_ok=True)
                cache_hit = False
        if not cache_hit:
            self.samples = self._build_samples()
            if cache_path is not None:
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({"samples": self.samples}, cache_path)
        # Noise is applied *after* caching so one noise-free cache serves all SNRs.
        if self.config.snr_db is not None:
            self._apply_noise()

    def _cache_path(self) -> Path:
        """Return the cache file path, keyed by every option that changes sample content."""
        cfg = self.config
        name = cfg.raw_path.stem
        ph, pw = cfg.patch_size
        # round() rather than int(): int(0.29 * 100) truncates to 28.
        keep = round(cfg.keep_percentage * 100)
        tmax = "all" if cfg.max_time_steps is None else str(cfg.max_time_steps)
        # v3: the key now includes max_time_steps; v2 keys collided across different
        # truncation lengths and could silently return wrongly-sized cached samples.
        cache_name = f"adseq_{name}_keep{keep}_{cfg.normalize}_t{tmax}_p{ph}x{pw}_{cfg.phase_mode}_v3.pt"
        return cfg.cache_dir / cache_name

    def _load_raw(self) -> Any:
        """Unpickle and return the raw payload at ``config.raw_path`` (trusted local file)."""
        with self.config.raw_path.open("rb") as handle:
            return pickle.load(handle)

    def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize a single complex sample by its own RMS magnitude."""
        rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
        return tensor / rms.to(tensor.dtype)

    def _build_samples(self) -> List[torch.Tensor]:
        """Convert the raw payload into a list of truncated angle-delay tensors.

        Returns:
            One complex tensor per sequence, normalized according to
            ``config.normalize`` ("per_sample_rms", "global_rms", or anything
            else for no normalization).
        """
        payload = self._load_raw()
        # Raw files either wrap the array in {"channel": ...} or store it bare.
        channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload
        channel_tensor = torch.as_tensor(channel, dtype=torch.complex64)
        if channel_tensor.ndim == 3:
            # Single sequence: add a leading batch dimension.
            channel_tensor = channel_tensor.unsqueeze(0)
        if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps:
            channel_tensor = channel_tensor[:, : self.config.max_time_steps]
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage))
        samples: List[torch.Tensor] = []
        for seq in channel_tensor:
            ad = processor.forward(seq)
            truncated, _ = processor.truncate_delay_bins(ad)
            samples.append(truncated)

        if self.config.normalize == "per_sample_rms":
            samples = [self._normalize_sample(s) for s in samples]
        elif self.config.normalize == "global_rms":
            # One RMS over the whole dataset, preserving relative sample power.
            total_sum_sq = 0.0
            total_count = 0
            for s in samples:
                s_real = s.real.float()
                s_imag = s.imag.float()
                total_sum_sq += (s_real ** 2 + s_imag ** 2).sum().item()
                total_count += s_real.numel()
            if total_count > 0:
                global_rms = max(math.sqrt(total_sum_sq / total_count), 1e-8)
                # Complex tensors divide cleanly by a Python float scalar.
                samples = [s / global_rms for s in samples]

        return samples

    def _apply_noise(self) -> None:
        """Add complex AWGN at ``config.snr_db`` (dB) to every sample, in place.

        A local ``torch.Generator`` is used when ``config.noise_seed`` is set,
        so seeding the noise no longer reseeds the process-wide RNG (the
        previous ``torch.manual_seed`` call clobbered global state).
        """
        generator: Optional[torch.Generator] = None
        if self.config.noise_seed is not None:
            generator = torch.Generator()
            generator.manual_seed(int(self.config.noise_seed))
        snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0)
        noisy: List[torch.Tensor] = []
        for sample in self.samples:
            real = sample.real.float()
            imag = sample.imag.float()
            power = (real.square() + imag.square()).mean().item()
            if power <= 0:
                # All-zero sample: SNR is undefined, keep it untouched.
                noisy.append(sample)
                continue
            # Split the total noise power evenly between real and imaginary parts.
            std = math.sqrt(power / snr_lin / 2.0)
            noise_real = torch.randn(real.shape, generator=generator, dtype=real.dtype) * std
            noise_imag = torch.randn(imag.shape, generator=generator, dtype=imag.dtype) * std
            # Build the complex noise from the float parts first, THEN cast:
            # the previous code cast each part to the sample's complex dtype
            # before torch.complex(), which rejects complex inputs and raised
            # a RuntimeError for complex samples.
            noise = torch.complex(noise_real, noise_imag).to(sample.dtype)
            noisy.append((sample + noise).to(sample.dtype))
        self.samples = noisy

    def __len__(self) -> int:
        """Number of processed sequences."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Tokenize sample ``index`` and return it alongside the raw sequence.

        Returns a dict with:
            sequence:  the (T, N, M) angle-delay tensor,
            tokens:    tokenizer output (batch dim squeezed away),
            base_mask: tokenizer mask (batch dim squeezed away),
            shape:     LongTensor [T, H, W] where H, W are the patch-grid size.
        """
        sample = self.samples[index]
        tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size)
        tokens = tokens.squeeze(0)
        base_mask = base_mask.squeeze(0)
        T, N, M = sample.shape
        ph, pw = self.config.patch_size
        H = N // ph
        W = M // pw
        shape = torch.tensor([T, H, W], dtype=torch.long)
        payload: Dict[str, Any] = {
            "sequence": sample,
            "tokens": tokens,
            "base_mask": base_mask,
            "shape": shape,
        }
        return payload
|
|
|
|
def load_adseq_dataset(
    data_path: str | Path,
    keep_percentage: float = 0.25,
    normalize: str = "global_rms",
    cache_dir: Optional[str | Path] = "cache",
    use_cache: bool = True,
    overwrite_cache: bool = False,
    logger: Optional[Any] = None,
    snr_db: Optional[float] = None,
    noise_seed: Optional[int] = None,
    max_time_steps: Optional[int] = None,
    patch_size: Tuple[int, int] = (1, 1),
    phase_mode: str = "real_imag",
) -> "AngleDelaySequenceDataset":
    """Convenience constructor for :class:`AngleDelaySequenceDataset`.

    Arguments mirror :class:`AngleDelayDatasetConfig` fields; ``data_path``
    and ``cache_dir`` also accept plain strings and are converted to ``Path``.
    ``patch_size`` and ``phase_mode`` were previously not forwarded and always
    fell back to the config defaults; they are now exposed with those same
    defaults, so existing callers are unaffected.
    """
    cfg = AngleDelayDatasetConfig(
        raw_path=Path(data_path),
        keep_percentage=keep_percentage,
        normalize=normalize,
        cache_dir=None if cache_dir is None else Path(cache_dir),
        use_cache=use_cache,
        overwrite_cache=overwrite_cache,
        snr_db=snr_db,
        noise_seed=noise_seed,
        max_time_steps=max_time_steps,
        patch_size=patch_size,
        phase_mode=phase_mode,
    )
    return AngleDelaySequenceDataset(cfg, logger=logger)
|
|