# LWM-Spectro / utils.py
import os
import pickle
import platform
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from typing import Optional, Dict, Any
from numpy.random import Generator, default_rng
try:
from tqdm import tqdm # type: ignore
except ImportError:  # pragma: no cover - optional dependency
    class tqdm:  # Minimal fallback: iterates the input and ignores progress-bar calls
        def __init__(self, iterable=None, *args, **kwargs): self.iterable = iterable
        def __iter__(self): return iter(self.iterable if self.iterable is not None else [])
        def set_postfix(self, *args, **kwargs): pass
# Optional deps for MATLAB .mat (v7.3 HDF5) loading
try:
import h5py # type: ignore
except Exception:
h5py = None # Fallback handled below
try:
from scipy.io import loadmat # type: ignore
except Exception:
loadmat = None # Only used if available
from collections import defaultdict
from torch.utils.data import TensorDataset, DataLoader
# Use tqdm for better progress display
USE_TQDM = True
def count_parameters(model, log: bool = True):
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
if log:
print(f"πŸ“Š Model: {total:,} total, {trainable:,} trainable")
return total
def generate_spectrograms_and_labels(scenario_name, spectrogram_path, cache_path):
    # Use the cache when a cache_path is provided and the file exists; otherwise load fresh
if cache_path and os.path.exists(cache_path):
with open(cache_path, 'rb') as f:
cached_data = pickle.load(f)
# Handle different cache formats
if isinstance(cached_data, dict) and 'samples' in cached_data:
spectrograms = cached_data['samples']
else:
spectrograms = cached_data
else:
# Load data directly if cache doesn't exist or cache_path is None
spectrograms = load_spectrogram_data(spectrogram_path)
# Create cache file (only if cache_path is provided)
if cache_path:
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
with open(cache_path, 'wb') as f:
pickle.dump(spectrograms, f)
labels = torch.zeros(len(spectrograms), dtype=torch.long)
# Convert list of tensors to single tensor if needed
if isinstance(spectrograms, list):
spectrograms = torch.stack(spectrograms)
return spectrograms, labels
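# Illustrative usage (file paths are placeholders, not part of this repo):
#   specs, labels = generate_spectrograms_and_labels(
#       scenario_name="example_scenario",   # not referenced inside this function
#       spectrogram_path="data/specs.mat",
#       cache_path=None,                    # pass a path to enable pickle caching
#   )
#   # specs: stacked spectrograms, labels: torch.zeros(N, dtype=torch.long)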
def load_spectrogram_data(path):
"""Load spectrogram data from a .pkl, .mat file, or directory.
Returns a numpy array with shape:
- (N, rows, cols) for single-channel spectrograms
- (N, C, rows, cols) for multi-channel spectrograms
"""
specs = []
def _load_from_pkl(file_path):
with open(file_path, 'rb') as f:
data = pickle.load(f)
if isinstance(data, dict) and 'spectrograms' in data:
arr = data['spectrograms']
if isinstance(arr, np.ndarray):
return arr
if isinstance(data, np.ndarray):
return data
return None
def _load_from_mat(file_path):
# Primary path: MATLAB v7.3 (HDF5) via h5py
if h5py is not None:
try:
with h5py.File(file_path, 'r') as f:
# Prefer 'spectrograms'; otherwise pick the largest numeric dataset
if 'spectrograms' in f:
ds = f['spectrograms']
else:
cand = []
def _collect(name, obj):
try:
if isinstance(obj, h5py.Dataset) and obj.dtype.kind in ('f','i','u','c','V'):
cand.append((name, obj))
except Exception:
pass
f.visititems(_collect)
if not cand:
return None
# pick the dataset with the most elements
name, ds = max(cand, key=lambda kv: np.prod(kv[1].shape) if hasattr(kv[1], 'shape') else 0)
# Complex handling: structured dtype with fields 'real'/'imag' or native complex dtype
if hasattr(ds.dtype, 'fields') and ds.dtype.fields and 'real' in ds.dtype.fields and 'imag' in ds.dtype.fields:
real = ds['real'][...]
imag = ds['imag'][...]
arr = real + 1j * imag
else:
arr = ds[...]
return np.array(arr)
except Exception:
# Fallback to scipy if available
pass
# Fallback path: older MATLAB formats via scipy.io.loadmat
if loadmat is not None:
try:
data = loadmat(file_path)
# Prefer exact key; else choose first suitable numeric array
if 'spectrograms' in data:
arr = data['spectrograms']
return np.array(arr)
for k, v in data.items():
if k.startswith('__'):
continue
if isinstance(v, np.ndarray) and v.ndim >= 2 and v.size > 0 and np.issubdtype(v.dtype, np.number):
return np.array(v)
except Exception:
pass
return None
def _normalize_shape(arr: np.ndarray) -> np.ndarray:
"""Normalize array to (N, rows, cols) or (N, C, rows, cols).
Handles both MATLAB-saved HDF5 layouts and already-normalized tensors:
- (rows, cols) -> (1, rows, cols)
- (rows, cols, N) -> (N, rows, cols)
- (N, rows, cols) -> (N, rows, cols)
- (rows, cols, C, N) -> (N, C, rows, cols)
- (N, C, rows, cols) -> (N, C, rows, cols)
"""
if arr.ndim == 2:
return arr[None, ...]
if arr.ndim == 3:
# Heuristic: if last dim looks like N, transpose; else assume already (N, rows, cols)
if arr.shape[2] > 4 and arr.shape[0] <= 512 and arr.shape[1] <= 512:
return np.transpose(arr, (2, 0, 1))
else:
return arr
if arr.ndim == 4:
# Two common patterns: (rows, cols, C, N) or (N, C, rows, cols)
# Detect by which axis likely holds N (#samples)
# If first axis is large and second is small (#channels), likely already (N, C, rows, cols)
if arr.shape[0] > 4 and arr.shape[1] in (1, 2, 4, 8, 16, 32):
return arr
# Else if last axis is large (N) and third axis is small (C), transpose
if arr.shape[3] > 4 and arr.shape[2] in (1, 2, 4, 8, 16, 32):
return np.transpose(arr, (3, 2, 0, 1))
# Fallback to original assumption
return np.transpose(arr, (3, 2, 0, 1))
return arr
# File path
if os.path.isfile(path):
if path.endswith('.pkl'):
arr = _load_from_pkl(path)
if arr is not None:
arr = _normalize_shape(arr)
return arr
if path.endswith('.mat'):
arr = _load_from_mat(path)
if arr is not None:
arr = _normalize_shape(arr)
return arr
return np.array([])
# Directory path
for root, _, files in os.walk(path):
for f in files:
file_path = os.path.join(root, f)
if f.endswith('.pkl'):
arr = _load_from_pkl(file_path)
elif f.endswith('.mat'):
arr = _load_from_mat(file_path)
else:
arr = None
if isinstance(arr, np.ndarray):
arr = _normalize_shape(arr)
# Consolidate into list of samples
if arr.ndim == 3:
# (N, rows, cols)
for i in range(arr.shape[0]):
specs.append(arr[i])
elif arr.ndim == 4:
# (N, C, rows, cols)
for i in range(arr.shape[0]):
specs.append(arr[i])
return np.array(specs) if specs else np.array([])
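# Shape-handling sketch (sizes are illustrative): a MATLAB v7.3 array saved as
# (rows, cols, N) = (128, 512, 1000) is normalized to (1000, 128, 512), and a
# (rows, cols, C, N) = (128, 512, 2, 1000) array to (1000, 2, 128, 512).
#   arr = load_spectrogram_data("data/specs.mat")   # placeholder path
#   arr.shape                                       # -> (N, rows, cols) or (N, C, rows, cols)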
def tokenizer_train(
spectrograms,
max_len=None,
masking_percent=0.4,
mask=False,
seed=None,
metadata=None,
dataset_stats=None,
normalization="dataset",
interleaved: bool = False,
show_progress: bool = True,
):
# Auto-calculate max_len if not provided
if max_len is None and len(spectrograms) > 0:
max_len = calculate_max_len_from_spectrogram(spectrograms[0])
print(f"Auto-calculated max_len: {max_len} (from spectrogram shape {spectrograms[0].shape})")
elif max_len is None:
max_len = 513 # fallback default
print(f"Using default max_len: {max_len}")
total_specs = len(spectrograms)
if show_progress:
print(f"Tokenizing {total_specs} samples...")
rng: Generator = default_rng(seed) if seed is not None else default_rng()
seq_groups = defaultdict(list)
tensor_samples = []
skipped_empty = 0
if metadata is not None:
meta_arrays = {k: np.asarray(v) for k, v in metadata.items()}
else:
meta_arrays = None
normalization = normalization or "dataset"
if normalization not in {"dataset", "per_sample"}:
raise ValueError(f"Unsupported normalization mode: {normalization}")
if dataset_stats is not None:
ds_mean = float(dataset_stats.get('mean', 0.0))
ds_std = float(dataset_stats.get('std', 1.0))
if abs(ds_std) < 1e-6:
ds_std = 1e-6
else:
ds_mean = 0.0
ds_std = 1.0
eps = 1e-6
iterator = spectrograms
if USE_TQDM and show_progress:
iterator = tqdm(spectrograms, desc="Tokenizing", total=total_specs)
for idx, spec in enumerate(iterator):
        spec_np = np.asarray(spec, dtype=np.float32)  # asarray avoids NumPy 2.x errors when a copy is required
mean_db = float(spec_np.mean())
std_db = float(spec_np.std())
if normalization == "per_sample":
denom = std_db if abs(std_db) > eps else eps
spec_proc = (spec_np - mean_db) / denom
else:
spec_proc = (spec_np - ds_mean) / ds_std
patch = patch_maker(spec_proc, interleaved=interleaved)
if patch.size == 0:
skipped_empty += 1
continue
n_patches = patch.shape[0]
patch_size = patch.shape[1] if patch.ndim > 1 else 16
n_masks = int(masking_percent * n_patches)
word2id = {
'[CLS]': np.full(patch_size, 0.2, dtype=np.float32),
'[MASK]': np.full(patch_size, 0.1, dtype=np.float32),
}
sample = make_sample(patch, word2id, n_masks, patch_size, mask=mask, rng=rng)
sample_meta = {}
if meta_arrays is not None:
for key, values in meta_arrays.items():
sample_meta[key] = values[idx]
sample_meta['power_stats'] = np.array([mean_db, std_db], dtype=np.float32)
if mask:
input_ids, masked_tokens, masked_pos = sample
seq_len = len(input_ids)
if seq_len <= 1:
continue
if masked_tokens:
masked_tokens = np.stack(masked_tokens).astype(np.float32, copy=False)
else:
masked_tokens = np.empty((0, patch_size), dtype=np.float32)
seq_groups[seq_len].append({
'input_ids': input_ids,
'masked_pos': masked_pos,
'masked_tokens': masked_tokens,
'n_patches': seq_len - 1,
**sample_meta,
})
else:
tensor_samples.append({
'sample': sample,
**sample_meta,
})
if skipped_empty:
print(f"⚠️ Skipped {skipped_empty} spectrograms with empty patches")
if mask:
filtered_data = {k: v for k, v in seq_groups.items() if k > 0 and v}
total_samples = sum(len(v) for v in filtered_data.values())
if not filtered_data:
print("Warning: No valid data after filtering!")
return {}
if show_progress:
print(f"βœ… Tokenization completed: {total_samples} samples across {len(filtered_data)} sequence lengths")
return {k: filtered_data[k] for k in sorted(filtered_data.keys())}
if not tensor_samples:
print("Warning: No validation data after processing!")
return torch.empty(0)
stacked = torch.stack([torch.tensor(item['sample'], dtype=torch.float32) if isinstance(item['sample'], np.ndarray)
else item['sample'] for item in tensor_samples])
if show_progress:
print(f"βœ… Tokenization completed: {len(tensor_samples)} validation samples")
return stacked
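# Illustrative calls (argument values are arbitrary examples):
#   groups = tokenizer_train(specs, masking_percent=0.4, mask=True, seed=0)
#   # -> {seq_len: [{'input_ids', 'masked_pos', 'masked_tokens', 'n_patches', 'power_stats', ...}, ...]}
#   val_tensor = tokenizer_train(specs, mask=False)
#   # -> stacked tensor of [CLS]+patch rows, shape (N, seq_len, patch_size)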
def calculate_max_len_from_spectrogram(spec, patch_rows=4, patch_cols=4):
"""
Calculate the maximum sequence length needed for a given spectrogram size.
Args:
spec: Spectrogram tensor/array
patch_rows: Number of rows per patch
patch_cols: Number of columns per patch
Returns:
int: Maximum sequence length (number of patches + 1 for CLS token)
"""
if hasattr(spec, 'shape'):
shape = spec.shape
else:
shape = spec
# Handle different shape formats
if len(shape) == 3 and shape[0] == 1: # [1, height, width]
n_rows, n_cols = shape[1], shape[2]
elif len(shape) == 4 and shape[0] == 1 and shape[1] == 1: # [1, 1, height, width]
n_rows, n_cols = shape[2], shape[3]
elif len(shape) == 2: # [height, width]
n_rows, n_cols = shape[0], shape[1]
else:
raise ValueError(f"Unexpected spec shape: {shape}")
n_patches_r = n_rows // patch_rows
n_patches_c = n_cols // patch_cols
total_patches = n_patches_r * n_patches_c
return total_patches + 1 # +1 for CLS token
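# Worked example: a 128x512 spectrogram with the default 4x4 patches gives
# (128 // 4) * (512 // 4) + 1 = 32 * 128 + 1 = 4097 (patches plus the CLS token).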
def patch_maker(spec, patch_rows=4, patch_cols=4, interleaved: bool = False):
# Handle normalized spectrograms: [1, height, width] or [1, 1, height, width]
if len(spec.shape) == 3 and spec.shape[0] == 1: # [1, height, width]
spec = spec.squeeze(0) # Remove batch dimension: [height, width]
elif len(spec.shape) == 4 and spec.shape[0] == 1 and spec.shape[1] == 1: # [1, 1, height, width]
spec = spec.squeeze(0).squeeze(0) # Remove both dimensions: [height, width]
elif len(spec.shape) == 2: # [height, width] - already processed
pass
else:
raise ValueError(f"Unexpected spec shape: {spec.shape}")
n_rows, n_cols = spec.shape
if interleaved:
# Treat last axis as interleaved [real, imag, real, imag, ...]
# Compute patches across columns in pairs (2x per complex bin)
n_patches_r = n_rows // patch_rows
n_complex_cols = n_cols // 2
n_patches_c = n_complex_cols // patch_cols
if n_patches_r == 0 or n_patches_c == 0:
print(f"❌ PATCH CREATION FAILED (interleaved): {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols}")
return np.array([])
# Crop to full patches: rows and 2x columns for interleaving
cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols * 2]
if cropped.size == 0:
print(f"⚠️ No patches generated from {n_rows}x{n_cols} spectrogram (interleaved)")
return np.array([])
# Reshape to (n_patches_r, patch_rows, n_patches_c, patch_cols*2)
reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols * 2)
result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols * 2)
return result.astype(np.float32, copy=False)
# Non-interleaved real-valued path (existing behavior)
n_patches_r, n_patches_c = n_rows // patch_rows, n_cols // patch_cols
if n_patches_r == 0 or n_patches_c == 0:
print(f"❌ PATCH CREATION FAILED: spectrogram {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols} patches")
print(f" n_patches_r: {n_patches_r}, n_patches_c: {n_patches_c}")
return np.array([])
cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols]
if cropped.size == 0:
print(f"⚠️ No patches generated from {n_rows}x{n_cols} spectrogram")
return np.array([])
reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols)
result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols)
return result.astype(np.float32, copy=False)
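# Shape sketch: a real-valued 128x512 spectrogram yields (32 * 128, 16) = (4096, 16)
# patches. With interleaved=True the 512 columns are read as 256 complex bins, giving
# (32 * 64, 32) = (2048, 32) patches of length patch_rows * patch_cols * 2.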
def make_sample(tokens, word2id, n_masks, patch_size, mask=True, rng: Generator | None = None):
rng = rng or default_rng()
input_ids = np.vstack((word2id['[CLS]'], tokens))
if not mask:
return torch.tensor(input_ids, dtype=torch.float32)
n_patches = tokens.shape[0]
if n_masks <= 0 or n_patches == 0:
masked_pos = np.empty(0, dtype=np.int64)
else:
n_masks = min(n_masks, n_patches)
mask_candidates = np.arange(1, n_patches + 1)
masked_pos = rng.choice(mask_candidates, size=n_masks, replace=False)
masked_tokens = []
for pos in masked_pos:
masked_tokens.append(input_ids[pos].astype(np.float32, copy=True))
rnd = rng.random()
if rnd < 0.1:
input_ids[pos] = rng.random(patch_size, dtype=np.float32)
elif rnd < 0.9:
input_ids[pos] = word2id['[MASK]']
return [input_ids.astype(np.float32, copy=False), masked_tokens, masked_pos]
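# Output sketch (mask=True): [input_ids, masked_tokens, masked_pos], where input_ids
# has shape (n_patches + 1, patch_size) with the [CLS] row at index 0, masked_pos
# holds 1-based row indices, and masked_tokens keeps the original vectors at those
# rows. With mask=False a single float32 tensor of input_ids is returned instead.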
def patch_reconstructor(patches, rows, cols, patch_rows=4, patch_cols=4):
if isinstance(patches, torch.Tensor): patches = patches.detach().cpu().numpy()
batch_size, num_patches, _ = patches.shape
n_h, n_w = rows // patch_rows, cols // patch_cols
patches = patches.reshape(batch_size, n_h, n_w, patch_rows, patch_cols)
reconstructed = np.zeros((batch_size, rows, cols))
for i in range(n_h):
for j in range(n_w):
reconstructed[:, i*patch_rows:(i+1)*patch_rows, j*patch_cols:(j+1)*patch_cols] = patches[:, i, j]
return reconstructed
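# Round-trip sketch with patch_maker (shapes only; sizes are illustrative):
#   patches = patch_maker(spec)                               # (4096, 16) for a 128x512 input
#   recon = patch_reconstructor(patches[None, ...], 128, 512) # -> (1, 128, 512)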
def plot_radar_chart(names, opt_scores, base_scores, save_path="results/chart.png"):
try:
import matplotlib.pyplot as plt
from math import pi
N = len(names)
angles = [n/float(N)*2*pi for n in range(N)] + [0]
fig, ax = plt.subplots(subplot_kw=dict(projection='polar'))
ax.plot(angles, opt_scores + opt_scores[:1], 'o-', label='Optimized', color='#1f77b4')
ax.fill(angles, opt_scores + opt_scores[:1], alpha=0.25, color='#1f77b4')
ax.plot(angles, base_scores + base_scores[:1], 'o-', label='Baseline', color='#ff7f0e')
ax.fill(angles, base_scores + base_scores[:1], alpha=0.25, color='#ff7f0e')
ax.set_xticks(angles[:-1]); ax.set_xticklabels(names)
ax.set_ylim(0, 1); ax.legend(); ax.grid(True, alpha=0.3)
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight'); plt.close()
        print(f"πŸ“Š Chart saved: {save_path}")
    except Exception as e:
        print(f"⚠️ Radar chart not generated: {e}")
class MaskedSpectrogramDataset(torch.utils.data.Dataset):
"""Lazy dataset that materializes masked spectrogram samples per access."""
def __init__(self, samples):
self.samples = samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample = self.samples[idx]
input_ids = torch.from_numpy(sample['input_ids']).float()
masked_tokens = torch.from_numpy(sample['masked_tokens']).float()
masked_pos = torch.from_numpy(sample['masked_pos']).long()
snr_db = torch.tensor(sample.get('snr_db', 0.0), dtype=torch.float32)
doppler_id = torch.tensor(sample.get('doppler_id', 0), dtype=torch.long)
power_stats = torch.tensor(sample.get('power_stats', np.zeros(2, dtype=np.float32)), dtype=torch.float32)
snr_id = torch.tensor(sample.get('snr_id', -1), dtype=torch.long)
modulation_id = torch.tensor(sample.get('modulation_id', -1), dtype=torch.long)
return (
input_ids,
masked_tokens,
masked_pos,
snr_db,
doppler_id,
power_stats,
snr_id,
modulation_id,
)
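# Usage sketch: wrap one sequence-length group produced by tokenizer_train(mask=True).
#   ds = MaskedSpectrogramDataset(groups[max(groups)])
#   ids, tokens, pos, snr_db, doppler_id, power_stats, snr_id, mod_id = ds[0]
# Metadata keys missing from a sample fall back to the defaults above (0, -1, zeros).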
def create_train_dataloader(data, batch_size, shuffle, num_workers=0):
loaders = {}
for seq_len, group in data.items():
print(f"Dataloader: Processing seq_len={seq_len} with {len(group)} samples")
# Expect labels to be provided as group_labels in data if available
group_labels = None
if isinstance(group, tuple) and len(group) == 2:
group, group_labels = group
# Masked data with dict structure
if isinstance(group[0], dict):
print(" Processing as masked data (dict structure)")
dataset = MaskedSpectrogramDataset(group)
loaders[seq_len] = DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
)
print(f" Created DataLoader with {len(dataset)} samples (lazy loading)")
elif isinstance(group[0], list):
print(" Processing as masked data (list structure)")
ids, tokens, pos = zip(*group)
# If labels are available, use them; else, use zeros
if group_labels is not None:
label_tensor = torch.tensor(group_labels, dtype=torch.long)
else:
label_tensor = torch.zeros(len(group), dtype=torch.long)
dataset = TensorDataset(torch.tensor(ids, dtype=torch.float32),
torch.tensor(tokens, dtype=torch.float32),
torch.tensor(pos, dtype=torch.long),
label_tensor)
loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)
print(f" Created DataLoader with {len(dataset)} samples (with labels)")
else:
print(" Processing as non-masked data")
if isinstance(group[0], torch.Tensor):
dataset = TensorDataset(*group)
else:
tensor_group = [torch.tensor(g, dtype=torch.float32) for g in group]
dataset = TensorDataset(*tensor_group)
loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)
print(f" Created DataLoader with {len(dataset)} samples")
return loaders
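# Illustrative wiring (batch size is an arbitrary choice):
#   train_loaders = create_train_dataloader(groups, batch_size=64, shuffle=True)
#   for seq_len, loader in train_loaders.items():
#       for ids, tokens, pos, *meta in loader:
#           ...  # ids: (B, seq_len, patch_size), pos: (B, n_masks)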
def train_lwm(
model,
train_loaders,
val_loaders,
optimizer,
scheduler,
epochs,
device,
save_dir="models",
log_file="training_log.csv",
checkpoint_suffix: str = "",
distributed_context: Optional[Dict[str, Any]] = None,
):
distributed_context = distributed_context or {}
is_distributed = distributed_context.get("is_distributed", False)
rank = distributed_context.get("rank", 0)
world_size = max(1, distributed_context.get("world_size", 1))
is_primary = distributed_context.get("is_primary", rank == 0)
os.makedirs(save_dir, exist_ok=True)
    # Initialize logging (CSV path honors the log_file argument)
    log_file_path = os.path.join(save_dir, log_file)
use_tensorboard = False
writer = None
# Try to initialize TensorBoard writer
if is_primary:
try:
from torch.utils.tensorboard import SummaryWriter
tensorboard_dir = f"{save_dir}/tensorboard"
writer = SummaryWriter(tensorboard_dir)
print(f"πŸ“Š TensorBoard logs will be saved to: {tensorboard_dir}")
use_tensorboard = True
except (ImportError, AttributeError) as e:
print(f"⚠️ TensorBoard not available ({e}), using CSV logging instead")
# Initialize CSV logging as fallback
with open(log_file_path, 'w') as f:
f.write("epoch,train_loss,val_loss,val_nmse,lr\n")
criterion = nn.MSELoss(reduction='sum')
best_mse = float('inf')
train_losses, val_losses, val_nmse_losses = [], [], []
# Early stopping parameters
patience = 3 # Stop if no improvement for 3 epochs
patience_counter = 0
def _sync_sum(value: float) -> float:
if not is_distributed or not dist.is_available() or not dist.is_initialized():
return float(value)
tensor = torch.tensor(value, dtype=torch.float64, device=device)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
return float(tensor.item())
for epoch in range(epochs):
# Training
model.train()
train_mse, train_samples = 0.0, 0
if is_primary:
print(f"\nEpoch {epoch+1}/{epochs}")
for loader in train_loaders.values():
pbar = tqdm(
loader,
desc="Train",
postfix={"loss": 0.0, "avg_loss": 0.0},
disable=not is_primary,
)
for batch in pbar:
optimizer.zero_grad()
if len(batch) >= 3:
ids, tokens, pos = batch[0], batch[1], batch[2]
else:
raise ValueError(f"Unexpected batch length: {len(batch)}")
ids = ids.to(device).float()
tokens = tokens.to(device).float()
pos = pos.to(device).long()
logits = model(ids, pos)[0]
loss = criterion(tokens, logits)
loss.backward(); optimizer.step(); scheduler.step()
train_mse += loss.item(); train_samples += ids.shape[0]
# Update tqdm postfix with real-time metrics
current_avg_loss = train_mse / max(train_samples, 1)
batch_size = ids.shape[0]
if is_primary:
pbar.set_postfix({
"loss": f"{loss.item()/batch_size:.4f}",
"avg_loss": f"{current_avg_loss:.4f}"
})
total_train_mse = _sync_sum(train_mse)
total_train_samples = _sync_sum(train_samples)
train_mse = total_train_mse / max(total_train_samples, 1)
train_losses.append(train_mse)
# Log training metrics
if use_tensorboard and writer:
writer.add_scalar('Loss/train', train_mse, epoch + 1)
writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch + 1)
elif is_primary:
# Log to CSV
lr = optimizer.param_groups[0]['lr']
with open(log_file_path, 'a') as f:
f.write(f"{epoch+1},{train_mse},,,{lr}\n")
# Validation every epoch
model.eval()
val_mse, val_nmse, val_samples = 0.0, 0.0, 0
with torch.no_grad():
for loader in val_loaders.values():
progress_bar = tqdm(
loader,
desc="Val",
postfix={"mse": 0.0, "nmse": 0.0},
disable=not is_primary,
)
for batch in progress_bar:
                    # Masked validation batches yield (ids, tokens, pos, ...); non-masked ones yield a single tensor
if len(batch) >= 3:
# Masked validation data (training format)
ids, tokens, pos = batch[0], batch[1], batch[2]
ids = ids.to(device).float()
tokens = tokens.to(device).float()
pos = pos.to(device).long()
logits = model(ids, pos)[0]
elif len(batch) == 1:
# Non-masked validation data (tensor format)
val_tensor = batch[0].to(device, dtype=torch.float32) if 'mps' in str(device) else batch[0].to(device)
# For validation, call model without masked_pos (None)
output = model(val_tensor) # Returns [batch_size, seq_len, d_model]
# Apply decoder to get predictions in original dimension
# Handle DataParallel wrapper
model_module = model.module if hasattr(model, 'module') else model
logits = model_module.decoder(output) + model_module.decoder_bias # [batch_size, seq_len, element_length]
# For non-masked validation, tokens = input (no masking applied)
tokens = val_tensor
ids = val_tensor
else:
raise ValueError(f"Unexpected batch length: {len(batch)}")
val_mse += criterion(tokens, logits).item()
# Safe numpy conversion for MPS compatibility
tokens_np = tokens.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else tokens.cpu().numpy()
logits_np = logits.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else logits.cpu().numpy()
nmse_val = nmse_loss(tokens_np, logits_np)
val_nmse += nmse_val * ids.shape[0]
val_samples += ids.shape[0]
# Update progress bar with real-time metrics
current_mse = val_mse / max(val_samples, 1)
current_nmse = val_nmse / max(val_samples, 1)
current_nmse_db = 10 * np.log10(max(current_nmse, 1e-8)) # Convert to dB scale
batch_size = ids.shape[0]
if is_primary:
progress_bar.set_postfix({
"mse": f"{current_mse:.4f}",
"nmse": f"{current_nmse_db:.2f}dB"
})
total_val_mse = _sync_sum(val_mse)
total_val_nmse = _sync_sum(val_nmse)
total_val_samples = _sync_sum(val_samples)
val_mse = total_val_mse / max(total_val_samples, 1)
val_nmse = total_val_nmse / max(total_val_samples, 1)
val_losses.append(val_mse)
val_nmse_losses.append(val_nmse)
# Log validation metrics
if use_tensorboard and writer:
writer.add_scalar('Loss/validation', val_mse, epoch + 1)
writer.add_scalar('Loss/nmse', val_nmse, epoch + 1)
elif is_primary:
# Update CSV with validation metrics
lr = optimizer.param_groups[0]['lr']
# Read the last line and update it with validation metrics
with open(log_file_path, 'r') as f:
lines = f.readlines()
if lines:
# Update the last line with validation metrics
last_line = lines[-1].strip()
parts = last_line.split(',')
if len(parts) >= 5:
parts[2] = f"{val_mse}"
parts[3] = f"{val_nmse}"
lines[-1] = ','.join(parts) + '\n'
with open(log_file_path, 'w') as f:
f.writelines(lines)
if val_mse < best_mse:
best_mse = val_mse
patience_counter = 0 # Reset counter on improvement
suffix = checkpoint_suffix or ""
if is_primary:
path = f"{save_dir}/lwm_epoch{epoch+1}_val{val_mse:.4f}{suffix}.pth"
torch.save(model.state_dict(), path)
print(f"βœ… Saved: {path}")
else:
patience_counter += 1
if is_primary:
print(f"⏸️ No improvement for {patience_counter}/{patience} epochs")
# Early stopping check
if patience_counter >= patience:
if is_primary:
print(f"πŸ›‘ Early stopping triggered after {epoch+1} epochs")
print(f" Best validation MSE: {best_mse:.4f}")
break
if is_primary:
print(f"Train MSE: {train_mse:.4f}")
val_nmse_db = 10 * np.log10(max(val_nmse, 1e-8))
print(f"Val MSE: {val_mse:.4f}, NMSE: {val_nmse_db:.2f}dB")
    # Ensure val_losses and val_nmse_losses have the same length as train_losses;
    # pad any missing validation entries with None
while len(val_losses) < len(train_losses):
val_losses.append(None)
while len(val_nmse_losses) < len(train_losses):
val_nmse_losses.append(None)
# Save training history
# Convert numpy types to Python native types for JSON serialization
def convert_numpy_types(obj):
"""Convert numpy types to Python native types for JSON serialization"""
if isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, list):
return [convert_numpy_types(item) for item in obj]
elif isinstance(obj, dict):
return {key: convert_numpy_types(value) for key, value in obj.items()}
else:
return obj
training_history = {
'train_losses': convert_numpy_types(train_losses),
'val_losses': convert_numpy_types(val_losses),
'val_nmse_losses': convert_numpy_types(val_nmse_losses),
        'epochs': list(range(1, len(train_losses) + 1)),  # epochs actually run (early stopping may end sooner)
'best_val_mse': convert_numpy_types(best_mse)
}
if is_primary:
import json
history_file = f"{save_dir}/training_history.json"
with open(history_file, 'w') as f:
json.dump(training_history, f, indent=2)
print(f"πŸ“Š Training history saved: {history_file}")
# Close TensorBoard writer
if use_tensorboard and writer:
writer.close()
print(f"πŸ“Š TensorBoard logs saved: {tensorboard_dir}")
else:
print(f"πŸ“Š Training logs saved: {log_file_path}")
elif use_tensorboard and writer:
writer.close()
return model
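# Illustrative training call (the optimizer/scheduler choices below are assumptions,
# not requirements of this module):
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
#   model = train_lwm(model, train_loaders, val_loaders, optimizer, scheduler,
#                     epochs=10, device=torch.device("cuda"))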
def nmse_loss(y_true, y_pred):
if isinstance(y_true, torch.Tensor):
mse = torch.mean((y_true - y_pred) ** 2)
power = torch.mean(y_true ** 2)
else:
mse = np.mean((y_true - y_pred) ** 2)
power = np.mean(y_true ** 2)
return mse / (power + 1e-8)
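# NMSE is the mean squared error normalized by the signal power (plus a small
# epsilon); train_lwm reports it in dB as 10 * log10(nmse), so nmse == 0.01
# corresponds to -20 dB.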