import torch, torch.nn as nn, numpy as np, os, pickle, platform
import torch.distributed as dist
from typing import Optional, Dict, Any
from numpy.random import Generator, default_rng

try:
    from tqdm import tqdm  # type: ignore
except ImportError:  # pragma: no cover - optional dependency
    class tqdm:  # minimal stand-in so set_postfix calls below stay safe
        def __init__(self, iterable=None, *args, **kwargs):
            self.iterable = iterable
        def __iter__(self):
            return iter(self.iterable)
        def set_postfix(self, *args, **kwargs):
            pass

# Optional deps for MATLAB .mat (v7.3 HDF5) loading
try:
    import h5py  # type: ignore
except Exception:
    h5py = None  # Fallback handled below
try:
    from scipy.io import loadmat  # type: ignore
except Exception:
    loadmat = None  # Only used if available

from collections import defaultdict
from torch.utils.data import TensorDataset, DataLoader

# Use tqdm for better progress display
USE_TQDM = True

def count_parameters(model, log: bool = True):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if log:
        print(f"📊 Model: {total:,} total, {trainable:,} trainable")
    return total
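
# Illustrative sketch (not part of the original module): count_parameters on a
# stand-in single-layer model; the nn.Linear sizes are arbitrary.
def _example_count_parameters():
    m = nn.Linear(16, 4)  # 16 * 4 weights + 4 biases = 68 parameters
    assert count_parameters(m, log=False) == 68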

def generate_spectrograms_and_labels(scenario_name, spectrogram_path, cache_path):
    # Skip the cache entirely when cache_path is None
    if cache_path and os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            cached_data = pickle.load(f)
        # Handle different cache formats
        if isinstance(cached_data, dict) and 'samples' in cached_data:
            spectrograms = cached_data['samples']
        else:
            spectrograms = cached_data
    else:
        # Load data directly if the cache doesn't exist or cache_path is None
        spectrograms = load_spectrogram_data(spectrogram_path)
        # Create the cache file (only if cache_path is provided)
        if cache_path:
            # Guard against a bare filename whose dirname is ''
            os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)
            with open(cache_path, 'wb') as f:
                pickle.dump(spectrograms, f)
    labels = torch.zeros(len(spectrograms), dtype=torch.long)
    # Convert a list of tensors to a single tensor if needed
    if isinstance(spectrograms, list):
        spectrograms = torch.stack(spectrograms)
    return spectrograms, labels

def load_spectrogram_data(path):
    """Load spectrogram data from a .pkl file, .mat file, or directory.

    Returns a numpy array with shape:
    - (N, rows, cols) for single-channel spectrograms
    - (N, C, rows, cols) for multi-channel spectrograms
    """
    specs = []

    def _load_from_pkl(file_path):
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        if isinstance(data, dict) and 'spectrograms' in data:
            arr = data['spectrograms']
            if isinstance(arr, np.ndarray):
                return arr
        if isinstance(data, np.ndarray):
            return data
        return None

    def _load_from_mat(file_path):
        # Primary path: MATLAB v7.3 (HDF5) via h5py
        if h5py is not None:
            try:
                with h5py.File(file_path, 'r') as f:
                    # Prefer 'spectrograms'; otherwise pick the largest numeric dataset
                    if 'spectrograms' in f:
                        ds = f['spectrograms']
                    else:
                        cand = []

                        def _collect(name, obj):
                            try:
                                if isinstance(obj, h5py.Dataset) and obj.dtype.kind in ('f', 'i', 'u', 'c', 'V'):
                                    cand.append((name, obj))
                            except Exception:
                                pass

                        f.visititems(_collect)
                        if not cand:
                            return None
                        # Pick the dataset with the most elements
                        name, ds = max(cand, key=lambda kv: np.prod(kv[1].shape) if hasattr(kv[1], 'shape') else 0)
                    # Complex handling: structured dtype with 'real'/'imag' fields or a native complex dtype
                    if hasattr(ds.dtype, 'fields') and ds.dtype.fields and 'real' in ds.dtype.fields and 'imag' in ds.dtype.fields:
                        real = ds['real'][...]
                        imag = ds['imag'][...]
                        arr = real + 1j * imag
                    else:
                        arr = ds[...]
                    return np.array(arr)
            except Exception:
                # Fall back to scipy if available
                pass
        # Fallback path: older MATLAB formats via scipy.io.loadmat
        if loadmat is not None:
            try:
                data = loadmat(file_path)
                # Prefer the exact key; else choose the first suitable numeric array
                if 'spectrograms' in data:
                    arr = data['spectrograms']
                    return np.array(arr)
                for k, v in data.items():
                    if k.startswith('__'):
                        continue
                    if isinstance(v, np.ndarray) and v.ndim >= 2 and v.size > 0 and np.issubdtype(v.dtype, np.number):
                        return np.array(v)
            except Exception:
                pass
        return None

    def _normalize_shape(arr: np.ndarray) -> np.ndarray:
        """Normalize array to (N, rows, cols) or (N, C, rows, cols).

        Handles both MATLAB-saved HDF5 layouts and already-normalized tensors:
        - (rows, cols)       -> (1, rows, cols)
        - (rows, cols, N)    -> (N, rows, cols)
        - (N, rows, cols)    -> (N, rows, cols)
        - (rows, cols, C, N) -> (N, C, rows, cols)
        - (N, C, rows, cols) -> (N, C, rows, cols)
        """
        if arr.ndim == 2:
            return arr[None, ...]
        if arr.ndim == 3:
            # Heuristic: if the last dim looks like N, transpose; else assume already (N, rows, cols)
            if arr.shape[2] > 4 and arr.shape[0] <= 512 and arr.shape[1] <= 512:
                return np.transpose(arr, (2, 0, 1))
            else:
                return arr
        if arr.ndim == 4:
            # Two common patterns: (rows, cols, C, N) or (N, C, rows, cols).
            # Detect by which axis likely holds N (#samples).
            # If the first axis is large and the second is small (#channels), likely already (N, C, rows, cols)
            if arr.shape[0] > 4 and arr.shape[1] in (1, 2, 4, 8, 16, 32):
                return arr
            # Else if the last axis is large (N) and the third axis is small (C), transpose
            if arr.shape[3] > 4 and arr.shape[2] in (1, 2, 4, 8, 16, 32):
                return np.transpose(arr, (3, 2, 0, 1))
            # Fall back to the original assumption
            return np.transpose(arr, (3, 2, 0, 1))
        return arr

    # File path
    if os.path.isfile(path):
        if path.endswith('.pkl'):
            arr = _load_from_pkl(path)
            if arr is not None:
                arr = _normalize_shape(arr)
                return arr
        if path.endswith('.mat'):
            arr = _load_from_mat(path)
            if arr is not None:
                arr = _normalize_shape(arr)
                return arr
        return np.array([])

    # Directory path
    for root, _, files in os.walk(path):
        for f in files:
            file_path = os.path.join(root, f)
            if f.endswith('.pkl'):
                arr = _load_from_pkl(file_path)
            elif f.endswith('.mat'):
                arr = _load_from_mat(file_path)
            else:
                arr = None
            if isinstance(arr, np.ndarray):
                arr = _normalize_shape(arr)
                # Consolidate into a list of samples: (N, rows, cols) or (N, C, rows, cols)
                if arr.ndim in (3, 4):
                    for i in range(arr.shape[0]):
                        specs.append(arr[i])
    return np.array(specs) if specs else np.array([])
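
# Illustrative sketch (assumption: a round-trip through a temporary .pkl file):
# a MATLAB-style (rows, cols, N) array comes back normalized to (N, rows, cols)
# by the transposition heuristic in _normalize_shape.
def _example_load_spectrograms():
    import tempfile
    raw = np.random.rand(128, 128, 10).astype(np.float32)  # (rows, cols, N)
    with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp:
        pickle.dump({'spectrograms': raw}, tmp)
        tmp_path = tmp.name
    try:
        arr = load_spectrogram_data(tmp_path)
        assert arr.shape == (10, 128, 128)
    finally:
        os.unlink(tmp_path)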

def tokenizer_train(
    spectrograms,
    max_len=None,
    masking_percent=0.4,
    mask=False,
    seed=None,
    metadata=None,
    dataset_stats=None,
    normalization="dataset",
    interleaved: bool = False,
    show_progress: bool = True,
):
    # Auto-calculate max_len if not provided
    if max_len is None and len(spectrograms) > 0:
        max_len = calculate_max_len_from_spectrogram(spectrograms[0])
        print(f"Auto-calculated max_len: {max_len} (from spectrogram shape {spectrograms[0].shape})")
    elif max_len is None:
        max_len = 513  # fallback default
        print(f"Using default max_len: {max_len}")
    total_specs = len(spectrograms)
    if show_progress:
        print(f"Tokenizing {total_specs} samples...")
    rng: Generator = default_rng(seed) if seed is not None else default_rng()
    seq_groups = defaultdict(list)
    tensor_samples = []
    skipped_empty = 0
    if metadata is not None:
        meta_arrays = {k: np.asarray(v) for k, v in metadata.items()}
    else:
        meta_arrays = None
    normalization = normalization or "dataset"
    if normalization not in {"dataset", "per_sample"}:
        raise ValueError(f"Unsupported normalization mode: {normalization}")
    if dataset_stats is not None:
        ds_mean = float(dataset_stats.get('mean', 0.0))
        ds_std = float(dataset_stats.get('std', 1.0))
        if abs(ds_std) < 1e-6:
            ds_std = 1e-6
    else:
        ds_mean = 0.0
        ds_std = 1.0
    eps = 1e-6
    iterator = spectrograms
    if USE_TQDM and show_progress:
        iterator = tqdm(spectrograms, desc="Tokenizing", total=total_specs)
    for idx, spec in enumerate(iterator):
        spec_np = np.array(spec, dtype=np.float32, copy=False)
        mean_db = float(spec_np.mean())
        std_db = float(spec_np.std())
        if normalization == "per_sample":
            denom = std_db if abs(std_db) > eps else eps
            spec_proc = (spec_np - mean_db) / denom
        else:
            spec_proc = (spec_np - ds_mean) / ds_std
        patch = patch_maker(spec_proc, interleaved=interleaved)
        if patch.size == 0:
            skipped_empty += 1
            continue
        n_patches = patch.shape[0]
        patch_size = patch.shape[1] if patch.ndim > 1 else 16
        n_masks = int(masking_percent * n_patches)
        word2id = {
            '[CLS]': np.full(patch_size, 0.2, dtype=np.float32),
            '[MASK]': np.full(patch_size, 0.1, dtype=np.float32),
        }
        sample = make_sample(patch, word2id, n_masks, patch_size, mask=mask, rng=rng)
        sample_meta = {}
        if meta_arrays is not None:
            for key, values in meta_arrays.items():
                sample_meta[key] = values[idx]
        sample_meta['power_stats'] = np.array([mean_db, std_db], dtype=np.float32)
        if mask:
            input_ids, masked_tokens, masked_pos = sample
            seq_len = len(input_ids)
            if seq_len <= 1:
                continue
            if masked_tokens:
                masked_tokens = np.stack(masked_tokens).astype(np.float32, copy=False)
            else:
                masked_tokens = np.empty((0, patch_size), dtype=np.float32)
            seq_groups[seq_len].append({
                'input_ids': input_ids,
                'masked_pos': masked_pos,
                'masked_tokens': masked_tokens,
                'n_patches': seq_len - 1,
                **sample_meta,
            })
        else:
            tensor_samples.append({
                'sample': sample,
                **sample_meta,
            })
    if skipped_empty:
        print(f"⚠️ Skipped {skipped_empty} spectrograms with empty patches")
    if mask:
        filtered_data = {k: v for k, v in seq_groups.items() if k > 0 and v}
        total_samples = sum(len(v) for v in filtered_data.values())
        if not filtered_data:
            print("Warning: No valid data after filtering!")
            return {}
        if show_progress:
            print(f"✅ Tokenization completed: {total_samples} samples across {len(filtered_data)} sequence lengths")
        return {k: filtered_data[k] for k in sorted(filtered_data.keys())}
    if not tensor_samples:
        print("Warning: No validation data after processing!")
        return torch.empty(0)
    stacked = torch.stack([
        torch.tensor(item['sample'], dtype=torch.float32) if isinstance(item['sample'], np.ndarray)
        else item['sample']
        for item in tensor_samples
    ])
    if show_progress:
        print(f"✅ Tokenization completed: {len(tensor_samples)} validation samples")
    return stacked
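
# Illustrative sketch (assumption: synthetic data, default 4x4 patches): a minimal
# masked tokenizer_train call. 64x64 inputs give 256 patches + [CLS] = seq_len 257,
# so the result is a dict with the single key 257.
def _example_tokenizer_train():
    specs = np.random.randn(4, 64, 64).astype(np.float32)
    groups = tokenizer_train(specs, mask=True, seed=0, show_progress=False)
    assert list(groups.keys()) == [257] and len(groups[257]) == 4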

def calculate_max_len_from_spectrogram(spec, patch_rows=4, patch_cols=4):
    """
    Calculate the maximum sequence length needed for a given spectrogram size.

    Args:
        spec: Spectrogram tensor/array
        patch_rows: Number of rows per patch
        patch_cols: Number of columns per patch

    Returns:
        int: Maximum sequence length (number of patches + 1 for the CLS token)
    """
    if hasattr(spec, 'shape'):
        shape = spec.shape
    else:
        shape = spec
    # Handle different shape formats
    if len(shape) == 3 and shape[0] == 1:  # [1, height, width]
        n_rows, n_cols = shape[1], shape[2]
    elif len(shape) == 4 and shape[0] == 1 and shape[1] == 1:  # [1, 1, height, width]
        n_rows, n_cols = shape[2], shape[3]
    elif len(shape) == 2:  # [height, width]
        n_rows, n_cols = shape[0], shape[1]
    else:
        raise ValueError(f"Unexpected spec shape: {shape}")
    n_patches_r = n_rows // patch_rows
    n_patches_c = n_cols // patch_cols
    total_patches = n_patches_r * n_patches_c
    return total_patches + 1  # +1 for the CLS token
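
# Worked example: a 512x512 spectrogram with 4x4 patches yields
# (512 // 4) * (512 // 4) = 128 * 128 = 16384 patches, so max_len = 16385.
def _example_max_len():
    assert calculate_max_len_from_spectrogram(np.zeros((512, 512))) == 16385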

def patch_maker(spec, patch_rows=4, patch_cols=4, interleaved: bool = False):
    # Handle normalized spectrograms: [1, height, width] or [1, 1, height, width]
    if len(spec.shape) == 3 and spec.shape[0] == 1:  # [1, height, width]
        spec = spec.squeeze(0)  # Remove batch dimension: [height, width]
    elif len(spec.shape) == 4 and spec.shape[0] == 1 and spec.shape[1] == 1:  # [1, 1, height, width]
        spec = spec.squeeze(0).squeeze(0)  # Remove both dimensions: [height, width]
    elif len(spec.shape) == 2:  # [height, width] - already processed
        pass
    else:
        raise ValueError(f"Unexpected spec shape: {spec.shape}")
    n_rows, n_cols = spec.shape
    if interleaved:
        # Treat the last axis as interleaved [real, imag, real, imag, ...]:
        # compute patches across columns in pairs (2x per complex bin)
        n_patches_r = n_rows // patch_rows
        n_complex_cols = n_cols // 2
        n_patches_c = n_complex_cols // patch_cols
        if n_patches_r == 0 or n_patches_c == 0:
            print(f"❌ PATCH CREATION FAILED (interleaved): {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols}")
            return np.array([])
        # Crop to full patches: rows and 2x columns for interleaving
        cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols * 2]
        if cropped.size == 0:
            print(f"⚠️ No patches generated from {n_rows}x{n_cols} spectrogram (interleaved)")
            return np.array([])
        # Reshape to (n_patches_r, patch_rows, n_patches_c, patch_cols*2)
        reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols * 2)
        result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols * 2)
        return result.astype(np.float32, copy=False)
    # Non-interleaved real-valued path (existing behavior)
    n_patches_r, n_patches_c = n_rows // patch_rows, n_cols // patch_cols
    if n_patches_r == 0 or n_patches_c == 0:
        print(f"❌ PATCH CREATION FAILED: spectrogram {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols} patches")
        print(f"   n_patches_r: {n_patches_r}, n_patches_c: {n_patches_c}")
        return np.array([])
    cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols]
    if cropped.size == 0:
        print(f"⚠️ No patches generated from {n_rows}x{n_cols} spectrogram")
        return np.array([])
    reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols)
    result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols)
    return result.astype(np.float32, copy=False)
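
# Illustrative sketch: patch_maker flattens a (rows, cols) spectrogram into
# row-major 4x4 patches of 16 values each.
def _example_patch_maker():
    spec = np.arange(64, dtype=np.float32).reshape(8, 8)
    patches = patch_maker(spec)  # (8 // 4) * (8 // 4) = 4 patches
    assert patches.shape == (4, 16)
    # The first patch is the top-left 4x4 block, flattened row by row
    assert np.array_equal(patches[0], spec[:4, :4].reshape(-1))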

def make_sample(tokens, word2id, n_masks, patch_size, mask=True, rng: Optional[Generator] = None):
    rng = rng or default_rng()
    input_ids = np.vstack((word2id['[CLS]'], tokens))
    if not mask:
        return torch.tensor(input_ids, dtype=torch.float32)
    n_patches = tokens.shape[0]
    if n_masks <= 0 or n_patches == 0:
        masked_pos = np.empty(0, dtype=np.int64)
    else:
        n_masks = min(n_masks, n_patches)
        mask_candidates = np.arange(1, n_patches + 1)
        masked_pos = rng.choice(mask_candidates, size=n_masks, replace=False)
    masked_tokens = []
    for pos in masked_pos:
        masked_tokens.append(input_ids[pos].astype(np.float32, copy=True))
        rnd = rng.random()
        if rnd < 0.1:
            input_ids[pos] = rng.random(patch_size, dtype=np.float32)
        elif rnd < 0.9:
            input_ids[pos] = word2id['[MASK]']
    return [input_ids.astype(np.float32, copy=False), masked_tokens, masked_pos]
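
# Note on the masking scheme above: each selected position keeps its original
# patch with probability 0.1, is replaced by random values with probability 0.1,
# and receives the [MASK] token otherwise (a BERT-style 80/10/10 split).
# Illustrative sketch with a fixed seed:
def _example_make_sample():
    tokens = np.random.rand(8, 16).astype(np.float32)
    word2id = {'[CLS]': np.full(16, 0.2, np.float32), '[MASK]': np.full(16, 0.1, np.float32)}
    ids, masked_tokens, masked_pos = make_sample(tokens, word2id, n_masks=3,
                                                 patch_size=16, rng=default_rng(0))
    assert ids.shape == (9, 16) and len(masked_tokens) == 3 == len(masked_pos)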

def patch_reconstructor(patches, rows, cols, patch_rows=4, patch_cols=4):
    if isinstance(patches, torch.Tensor):
        patches = patches.detach().cpu().numpy()
    batch_size, num_patches, _ = patches.shape
    n_h, n_w = rows // patch_rows, cols // patch_cols
    patches = patches.reshape(batch_size, n_h, n_w, patch_rows, patch_cols)
    reconstructed = np.zeros((batch_size, rows, cols))
    for i in range(n_h):
        for j in range(n_w):
            reconstructed[:, i*patch_rows:(i+1)*patch_rows, j*patch_cols:(j+1)*patch_cols] = patches[:, i, j]
    return reconstructed
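
# Illustrative round-trip sketch: patch_maker followed by patch_reconstructor
# recovers the original spectrogram exactly when the dimensions divide evenly.
def _example_patch_roundtrip():
    spec = np.random.rand(16, 16).astype(np.float32)
    patches = patch_maker(spec)  # -> (16 patches, 16 values each)
    recon = patch_reconstructor(patches[None, ...], rows=16, cols=16)
    assert np.allclose(recon[0], spec)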

def plot_radar_chart(names, opt_scores, base_scores, save_path="results/chart.png"):
    try:
        import matplotlib.pyplot as plt
        from math import pi
        N = len(names)
        angles = [n / float(N) * 2 * pi for n in range(N)] + [0]
        fig, ax = plt.subplots(subplot_kw=dict(projection='polar'))
        ax.plot(angles, opt_scores + opt_scores[:1], 'o-', label='Optimized', color='#1f77b4')
        ax.fill(angles, opt_scores + opt_scores[:1], alpha=0.25, color='#1f77b4')
        ax.plot(angles, base_scores + base_scores[:1], 'o-', label='Baseline', color='#ff7f0e')
        ax.fill(angles, base_scores + base_scores[:1], alpha=0.25, color='#ff7f0e')
        ax.set_xticks(angles[:-1]); ax.set_xticklabels(names)
        ax.set_ylim(0, 1); ax.legend(); ax.grid(True, alpha=0.3)
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight'); plt.close()
        print(f"📊 Chart saved: {save_path}")
    except Exception as e:
        print(f"⚠️ Chart not saved (matplotlib unavailable or plotting failed: {e})")

class MaskedSpectrogramDataset(torch.utils.data.Dataset):
    """Lazy dataset that materializes masked spectrogram samples per access."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        input_ids = torch.from_numpy(sample['input_ids']).float()
        masked_tokens = torch.from_numpy(sample['masked_tokens']).float()
        masked_pos = torch.from_numpy(sample['masked_pos']).long()
        snr_db = torch.tensor(sample.get('snr_db', 0.0), dtype=torch.float32)
        doppler_id = torch.tensor(sample.get('doppler_id', 0), dtype=torch.long)
        power_stats = torch.tensor(sample.get('power_stats', np.zeros(2, dtype=np.float32)), dtype=torch.float32)
        snr_id = torch.tensor(sample.get('snr_id', -1), dtype=torch.long)
        modulation_id = torch.tensor(sample.get('modulation_id', -1), dtype=torch.long)
        return (
            input_ids,
            masked_tokens,
            masked_pos,
            snr_db,
            doppler_id,
            power_stats,
            snr_id,
            modulation_id,
        )

def create_train_dataloader(data, batch_size, shuffle, num_workers=0):
    loaders = {}
    for seq_len, group in data.items():
        print(f"Dataloader: Processing seq_len={seq_len} with {len(group)} samples")
        # Labels may be provided as a (group, group_labels) tuple
        group_labels = None
        if isinstance(group, tuple) and len(group) == 2:
            group, group_labels = group
        # Masked data with dict structure
        if isinstance(group[0], dict):
            print("  Processing as masked data (dict structure)")
            dataset = MaskedSpectrogramDataset(group)
            loaders[seq_len] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                pin_memory=True,
                num_workers=num_workers,
            )
            print(f"  Created DataLoader with {len(dataset)} samples (lazy loading)")
        elif isinstance(group[0], list):
            print("  Processing as masked data (list structure)")
            ids, tokens, pos = zip(*group)
            # If labels are available, use them; else use zeros
            if group_labels is not None:
                label_tensor = torch.tensor(group_labels, dtype=torch.long)
            else:
                label_tensor = torch.zeros(len(group), dtype=torch.long)
            dataset = TensorDataset(torch.tensor(ids, dtype=torch.float32),
                                    torch.tensor(tokens, dtype=torch.float32),
                                    torch.tensor(pos, dtype=torch.long),
                                    label_tensor)
            loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)
            print(f"  Created DataLoader with {len(dataset)} samples (with labels)")
        else:
            print("  Processing as non-masked data")
            if isinstance(group[0], torch.Tensor):
                dataset = TensorDataset(*group)
            else:
                tensor_group = [torch.tensor(g, dtype=torch.float32) for g in group]
                dataset = TensorDataset(*tensor_group)
            loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)
            print(f"  Created DataLoader with {len(dataset)} samples")
    return loaders
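
# Illustrative sketch (assumption: masked groups as produced by tokenizer_train):
# build per-sequence-length loaders and fetch one collated batch.
def _example_dataloaders():
    specs = np.random.randn(4, 64, 64).astype(np.float32)
    groups = tokenizer_train(specs, mask=True, seed=0, show_progress=False)
    loaders = create_train_dataloader(groups, batch_size=2, shuffle=True)
    ids, masked_tokens, masked_pos, *_ = next(iter(loaders[257]))
    assert ids.shape == (2, 257, 16)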

def train_lwm(
    model,
    train_loaders,
    val_loaders,
    optimizer,
    scheduler,
    epochs,
    device,
    save_dir="models",
    log_file="training_log.csv",
    checkpoint_suffix: str = "",
    distributed_context: Optional[Dict[str, Any]] = None,
):
    distributed_context = distributed_context or {}
    is_distributed = distributed_context.get("is_distributed", False)
    rank = distributed_context.get("rank", 0)
    world_size = max(1, distributed_context.get("world_size", 1))
    is_primary = distributed_context.get("is_primary", rank == 0)
    os.makedirs(save_dir, exist_ok=True)
    # Initialize logging (honor the log_file argument rather than hard-coding the name)
    log_file_path = os.path.join(save_dir, log_file)
    use_tensorboard = False
    writer = None
    tensorboard_dir = f"{save_dir}/tensorboard"
    # Try to initialize a TensorBoard writer
    if is_primary:
        try:
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter(tensorboard_dir)
            print(f"📊 TensorBoard logs will be saved to: {tensorboard_dir}")
            use_tensorboard = True
        except (ImportError, AttributeError) as e:
            print(f"⚠️ TensorBoard not available ({e}), using CSV logging instead")
            # Initialize CSV logging as a fallback
            with open(log_file_path, 'w') as f:
                f.write("epoch,train_loss,val_loss,val_nmse,lr\n")
    criterion = nn.MSELoss(reduction='sum')
    best_mse = float('inf')
    train_losses, val_losses, val_nmse_losses = [], [], []
    # Early stopping parameters
    patience = 3  # Stop if no improvement for 3 epochs
    patience_counter = 0

    def _sync_sum(value: float) -> float:
        if not is_distributed or not dist.is_available() or not dist.is_initialized():
            return float(value)
        tensor = torch.tensor(value, dtype=torch.float64, device=device)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return float(tensor.item())

    for epoch in range(epochs):
        # Training
        model.train()
        train_mse, train_samples = 0.0, 0
        if is_primary:
            print(f"\nEpoch {epoch+1}/{epochs}")
        for loader in train_loaders.values():
            pbar = tqdm(
                loader,
                desc="Train",
                postfix={"loss": 0.0, "avg_loss": 0.0},
                disable=not is_primary,
            )
            for batch in pbar:
                optimizer.zero_grad()
                if len(batch) >= 3:
                    ids, tokens, pos = batch[0], batch[1], batch[2]
                else:
                    raise ValueError(f"Unexpected batch length: {len(batch)}")
                ids = ids.to(device).float()
                tokens = tokens.to(device).float()
                pos = pos.to(device).long()
                logits = model(ids, pos)[0]
                loss = criterion(tokens, logits)
                loss.backward(); optimizer.step(); scheduler.step()
                train_mse += loss.item(); train_samples += ids.shape[0]
                # Update the tqdm postfix with real-time metrics
                current_avg_loss = train_mse / max(train_samples, 1)
                batch_size = ids.shape[0]
                if is_primary:
                    pbar.set_postfix({
                        "loss": f"{loss.item()/batch_size:.4f}",
                        "avg_loss": f"{current_avg_loss:.4f}"
                    })
        total_train_mse = _sync_sum(train_mse)
        total_train_samples = _sync_sum(train_samples)
        train_mse = total_train_mse / max(total_train_samples, 1)
        train_losses.append(train_mse)
        # Log training metrics
        if use_tensorboard and writer:
            writer.add_scalar('Loss/train', train_mse, epoch + 1)
            writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch + 1)
        elif is_primary:
            # Log to CSV
            lr = optimizer.param_groups[0]['lr']
            with open(log_file_path, 'a') as f:
                f.write(f"{epoch+1},{train_mse},,,{lr}\n")
        # Validation every epoch
        model.eval()
        val_mse, val_nmse, val_samples = 0.0, 0.0, 0
        with torch.no_grad():
            for loader in val_loaders.values():
                progress_bar = tqdm(
                    loader,
                    desc="Val",
                    postfix={"mse": 0.0, "nmse": 0.0},
                    disable=not is_primary,
                )
                for batch in progress_bar:
                    # Check whether the validation data is masked (3+ elements) or not (1 element)
                    if len(batch) >= 3:
                        # Masked validation data (training format)
                        ids, tokens, pos = batch[0], batch[1], batch[2]
                        ids = ids.to(device).float()
                        tokens = tokens.to(device).float()
                        pos = pos.to(device).long()
                        logits = model(ids, pos)[0]
                    elif len(batch) == 1:
                        # Non-masked validation data (tensor format)
                        val_tensor = batch[0].to(device, dtype=torch.float32) if 'mps' in str(device) else batch[0].to(device)
                        # For validation, call the model without masked_pos (None)
                        output = model(val_tensor)  # Returns [batch_size, seq_len, d_model]
                        # Apply the decoder to get predictions in the original dimension;
                        # handle a DataParallel wrapper
                        model_module = model.module if hasattr(model, 'module') else model
                        logits = model_module.decoder(output) + model_module.decoder_bias  # [batch_size, seq_len, element_length]
                        # For non-masked validation, tokens = input (no masking applied)
                        tokens = val_tensor
                        ids = val_tensor
                    else:
                        raise ValueError(f"Unexpected batch length: {len(batch)}")
                    val_mse += criterion(tokens, logits).item()
                    # Safe numpy conversion for MPS compatibility
                    tokens_np = tokens.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else tokens.cpu().numpy()
                    logits_np = logits.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else logits.cpu().numpy()
                    nmse_val = nmse_loss(tokens_np, logits_np)
                    val_nmse += nmse_val * ids.shape[0]
                    val_samples += ids.shape[0]
                    # Update the progress bar with real-time metrics
                    current_mse = val_mse / max(val_samples, 1)
                    current_nmse = val_nmse / max(val_samples, 1)
                    current_nmse_db = 10 * np.log10(max(current_nmse, 1e-8))  # Convert to dB scale
                    batch_size = ids.shape[0]
                    if is_primary:
                        progress_bar.set_postfix({
                            "mse": f"{current_mse:.4f}",
                            "nmse": f"{current_nmse_db:.2f}dB"
                        })
        total_val_mse = _sync_sum(val_mse)
        total_val_nmse = _sync_sum(val_nmse)
        total_val_samples = _sync_sum(val_samples)
        val_mse = total_val_mse / max(total_val_samples, 1)
        val_nmse = total_val_nmse / max(total_val_samples, 1)
        val_losses.append(val_mse)
        val_nmse_losses.append(val_nmse)
        # Log validation metrics
        if use_tensorboard and writer:
            writer.add_scalar('Loss/validation', val_mse, epoch + 1)
            writer.add_scalar('Loss/nmse', val_nmse, epoch + 1)
        elif is_primary:
            # Update this epoch's CSV row with validation metrics
            with open(log_file_path, 'r') as f:
                lines = f.readlines()
            if lines:
                # Update the last line with validation metrics
                last_line = lines[-1].strip()
                parts = last_line.split(',')
                if len(parts) >= 5:
                    parts[2] = f"{val_mse}"
                    parts[3] = f"{val_nmse}"
                    lines[-1] = ','.join(parts) + '\n'
                with open(log_file_path, 'w') as f:
                    f.writelines(lines)
        if val_mse < best_mse:
            best_mse = val_mse
            patience_counter = 0  # Reset the counter on improvement
            suffix = checkpoint_suffix or ""
            if is_primary:
                path = f"{save_dir}/lwm_epoch{epoch+1}_val{val_mse:.4f}{suffix}.pth"
                torch.save(model.state_dict(), path)
                print(f"✅ Saved: {path}")
        else:
            patience_counter += 1
            if is_primary:
                print(f"⏸️ No improvement for {patience_counter}/{patience} epochs")
        # Early stopping check
        if patience_counter >= patience:
            if is_primary:
                print(f"🛑 Early stopping triggered after {epoch+1} epochs")
                print(f"   Best validation MSE: {best_mse:.4f}")
            break
        if is_primary:
            print(f"Train MSE: {train_mse:.4f}")
            val_nmse_db = 10 * np.log10(max(val_nmse, 1e-8))
            print(f"Val MSE: {val_mse:.4f}, NMSE: {val_nmse_db:.2f}dB")
    # Ensure val_losses and val_nmse_losses have the same length as train_losses;
    # fill missing validation entries with None
    while len(val_losses) < len(train_losses):
        val_losses.append(None)
    while len(val_nmse_losses) < len(train_losses):
        val_nmse_losses.append(None)

    # Save the training history.
    # Convert numpy types to Python native types for JSON serialization
    def convert_numpy_types(obj):
        """Convert numpy types to Python native types for JSON serialization"""
        if isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, list):
            return [convert_numpy_types(item) for item in obj]
        elif isinstance(obj, dict):
            return {key: convert_numpy_types(value) for key, value in obj.items()}
        else:
            return obj

    training_history = {
        'train_losses': convert_numpy_types(train_losses),
        'val_losses': convert_numpy_types(val_losses),
        'val_nmse_losses': convert_numpy_types(val_nmse_losses),
        'epochs': list(range(1, epochs + 1)),
        'best_val_mse': convert_numpy_types(best_mse)
    }
    if is_primary:
        import json
        history_file = f"{save_dir}/training_history.json"
        with open(history_file, 'w') as f:
            json.dump(training_history, f, indent=2)
        print(f"📊 Training history saved: {history_file}")
        # Close the TensorBoard writer
        if use_tensorboard and writer:
            writer.close()
            print(f"📊 TensorBoard logs saved: {tensorboard_dir}")
        else:
            print(f"📊 Training logs saved: {log_file_path}")
    elif use_tensorboard and writer:
        writer.close()
    return model

def nmse_loss(y_true, y_pred):
    if isinstance(y_true, torch.Tensor):
        mse = torch.mean((y_true - y_pred) ** 2)
        power = torch.mean(y_true ** 2)
    else:
        mse = np.mean((y_true - y_pred) ** 2)
        power = np.mean(y_true ** 2)
    return mse / (power + 1e-8)
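
# Illustrative sketch: a prediction with a uniform 10% relative error has
# NMSE of about 0.01, i.e. roughly -20 dB on the 10*log10 scale used in validation.
def _example_nmse():
    y = np.random.rand(1000).astype(np.float32)
    y_hat = y * 1.1  # uniform 10% relative error
    nmse_db = 10 * np.log10(nmse_loss(y, y_hat))
    assert abs(nmse_db - (-20.0)) < 0.5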