Spaces:
Paused
Paused
#!/usr/bin/env python
# /// script
# dependencies = [
#   "jax[cuda12]",
#   "equinox",
#   "numpy",
#   "scipy",
#   "jaxtyping",
#   "tqdm",
# ]
# ///
"""Evaluate a trained Generator against held-out test samples.

For each configuration we compute an 11-dimensional feature vector of physical
observables. The Mahalanobis distance of the generated feature mean from the
real feature distribution gives a single scalar measure of model quality.

Per-sample feature vector
-------------------------
m, m^2, |m|       magnetisation and its moments
e, e^2            nearest-neighbour energy per spin (periodic BC)
C(r=1..8)         connected two-point correlation at r = 1, 2, 4, 8 (periodic BC)
s_mean/N          mean cluster size (4-connected, open BC)
s_max/N           largest cluster size (4-connected, open BC)

Ensemble statistics (printed for reference, not part of Mahalanobis)
--------------------------------------------------------------------
chi = N · Var(m) / T              magnetic susceptibility
C_v = N · Var(e) / T²             specific heat
U4  = 1 − <m^4> / (3 <m^2>^2)     Binder cumulant
                                  → 2/3 in the ordered phase
                                  → 0 in the disordered phase
                                  ≈ 0.61 at T_c for the 2D Ising model (L → ∞)

Distance
--------
D = sqrt( Δμ^T Σ_real^{-1} Δμ )

where Δμ = μ_gen − μ_real and Σ_real is the sample covariance of the
real test features. Per-feature z-scores Δμ_i / σ_real_i are also
reported so you can see which observables deviate most.
"""
| import argparse | |
| from pathlib import Path | |
| import numpy as np | |
| import scipy.ndimage | |
| import jax | |
| from tqdm.auto import tqdm | |
| from model import gen_config, snake_order | |
| from sample import load_checkpoint, sample_batch, tokens_to_grids | |
| from train import load_ising_data | |
# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------
J = 1.0  # nearest-neighbour coupling used in energy_per_spin (E = -J Σ s_i s_j)
# Exact critical temperature of the square-lattice Ising model (k_B = 1):
#   T_c = 2J / ln(1 + √2) ≈ 2.2692 for J = 1.
# Written in terms of J so the constant stays consistent if J is changed.
T_C = 2.0 * J / np.log(1.0 + np.sqrt(2.0))
# Column names of the per-sample feature matrix, in the order produced by
# compute_features(). The C(r) entries must match the default distances of
# connected_correlations().
FEATURE_NAMES = [
    "m", "m^2", "|m|",
    "e", "e^2",
    "C(r=1)", "C(r=2)", "C(r=4)", "C(r=8)",
    "s_mean/N", "s_max/N",
]
# ---------------------------------------------------------------------------
# Per-sample observables
# ---------------------------------------------------------------------------
def energy_per_spin(grid: np.ndarray) -> float:
    """Return the nearest-neighbour energy per spin under periodic boundaries.

    Computes  E/N = −(J/N) · Σ_⟨ij⟩ s_i s_j.  Every bond is visited exactly
    once by pairing each site with its right-hand and lower neighbour,
    wrapping around the lattice edges via np.roll.
    """
    bonds_h = grid * np.roll(grid, -1, axis=1)
    bonds_v = grid * np.roll(grid, -1, axis=0)
    total = (bonds_h + bonds_v).sum()
    return float(-J * total / grid.size)
def connected_correlations(
    grid: np.ndarray,
    distances: tuple[int, ...] = (1, 2, 4, 8),
) -> np.ndarray:
    """Return the isotropic connected two-point function at each distance.

    C(r) = ½[⟨s(x) s(x + r·x̂)⟩ + ⟨s(x) s(x + r·ŷ)⟩] − ⟨s⟩²

    Both lattice directions are averaged, every site acts as an origin,
    and shifts wrap around the edges (periodic boundary conditions).
    The result is a float64 array with one entry per requested distance.
    """
    mean_spin = float(grid.mean())

    def _shifted_product_mean(r: int, axis: int) -> float:
        # ⟨s(x) s(x + r along axis)⟩ averaged over all sites.
        return float((grid * np.roll(grid, r, axis=axis)).mean())

    return np.array(
        [
            (_shifted_product_mean(r, 1) + _shifted_product_mean(r, 0)) / 2.0
            - mean_spin ** 2
            for r in distances
        ],
        dtype=np.float64,
    )
def cluster_stats(grid: np.ndarray) -> tuple[float, float]:
    """Return (mean, max) cluster size over both spin species, divided by N.

    Clusters are maximal groups of equal spins joined by 4-connectivity
    (no diagonal links) with open boundaries, so a domain that wraps around
    the lattice edge is counted as separate pieces. The same rule is applied
    to real and generated grids, so this boundary effect is shared by both
    sets being compared. Dividing by N = grid.size makes the numbers
    comparable across lattice sizes. Returns (0.0, 0.0) when no site
    equals +1 or -1.
    """
    n_sites = grid.size
    per_species = []
    for spin in (1, -1):
        labels, count = scipy.ndimage.label(grid == spin)
        if count:
            # Label 0 marks sites that are not this spin; drop its count.
            per_species.append(np.bincount(labels.ravel())[1:])
    if len(per_species) == 0:
        return 0.0, 0.0
    sizes = np.concatenate(per_species).astype(np.float64)
    return float(sizes.mean()) / n_sites, float(sizes.max()) / n_sites
def compute_features(grid: np.ndarray) -> np.ndarray:
    """Return the 11-D float64 feature vector for one ±1 grid of shape (L, L).

    Entry order follows FEATURE_NAMES: m, m², |m|, e, e², C(1), C(2), C(4),
    C(8), s_mean/N, s_max/N.
    """
    magnetisation = float(grid.mean())
    energy = energy_per_spin(grid)
    mean_cluster, max_cluster = cluster_stats(grid)
    parts = (
        [magnetisation, magnetisation ** 2, abs(magnetisation)],
        [energy, energy ** 2],
        list(connected_correlations(grid)),
        [mean_cluster, max_cluster],
    )
    return np.array([value for part in parts for value in part], dtype=np.float64)
def compute_feature_matrix(grids: np.ndarray, desc: str = "features") -> np.ndarray:
    """Compute the (N, 11) feature matrix for a batch of grids.

    Parameters
    ----------
    grids : sequence of N ±1 grids, e.g. an (N, L, L) array.
    desc  : label shown on the tqdm progress bar.

    Returns
    -------
    (N, len(FEATURE_NAMES)) float64 array, one compute_features() row per
    grid. For empty input an empty (0, len(FEATURE_NAMES)) array is returned
    rather than letting np.stack fail on an empty list.
    """
    rows = [compute_features(g) for g in tqdm(grids, desc=desc, leave=False)]
    if not rows:
        return np.empty((0, len(FEATURE_NAMES)), dtype=np.float64)
    return np.stack(rows)
# ---------------------------------------------------------------------------
# Ensemble statistics
# ---------------------------------------------------------------------------
def ensemble_stats(
    X: np.ndarray,
    T: float = T_C,
    lattice_size: int | None = None,
) -> dict[str, float]:
    """Derive thermodynamic ensemble statistics from a feature matrix.

    Parameters
    ----------
    X            : (n_samples, 11) feature matrix from ``compute_feature_matrix``.
    T            : temperature used for the chi and C_v normalisation.
    lattice_size : linear size L of the grids that produced *X*. Defaults to
                   ``gen_config["lattice_size"]``. It is needed because the
                   per-spin features in *X* do not record the spin count N = L².

    Returns
    -------
    dict with keys "<|m|>", "chi" (N·Var(m)/T), "C_v" (N·Var(e)/T²) and
    "U4" (1 − <m^4>/(3<m^2>^2)). U4 is NaN when <m^2> is not positive.
    Variances use ddof=0.
    """
    L = gen_config["lattice_size"] if lattice_size is None else lattice_size
    n_spins = L * L
    m = X[:, FEATURE_NAMES.index("m")]
    m2 = X[:, FEATURE_NAMES.index("m^2")]
    e = X[:, FEATURE_NAMES.index("e")]
    m2_mean = float(m2.mean())
    chi = n_spins * float(m.var()) / T
    Cv = n_spins * float(e.var()) / T ** 2
    # Guard: the Binder ratio is undefined when every sample has m = 0.
    binder = (
        float(1.0 - (m ** 4).mean() / (3.0 * m2_mean ** 2))
        if m2_mean > 0
        else float("nan")
    )
    return {
        "<|m|>": float(np.abs(m).mean()),
        "chi": chi,
        "C_v": Cv,
        "U4": binder,
    }
# ---------------------------------------------------------------------------
# Mahalanobis distance
# ---------------------------------------------------------------------------
def mahalanobis_distance(
    X_ref: np.ndarray,
    X_query: np.ndarray,
    reg: float = 1e-6,
) -> tuple[float, np.ndarray]:
    """Mahalanobis distance of the query mean from the reference distribution.

        D = sqrt( Δμ^T Σ_ref^{-1} Δμ ),   Δμ = μ_query − μ_ref

    Also returns per-feature z-scores z_i = Δμ_i / σ_ref_i, where
    σ_ref_i = sqrt(Σ_ref[i, i]) including the regularisation term.
    |z_i| > 1 indicates a feature whose mean differs by more than one
    reference standard deviation.

    Parameters
    ----------
    X_ref   : (N, d) real / reference feature matrix, N >= 2
    X_query : (M, d) generated / query feature matrix, M >= 1
    reg     : value added to the diagonal of Σ_ref before solving

    Returns
    -------
    (D, z_scores) with z_scores of shape (d,). If the inputs contain NaN,
    D is NaN. It is never silently reported as 0.

    Raises
    ------
    ValueError
        If either input is not 2-D, the feature dimensions differ, X_ref has
        fewer than two rows (no sample covariance), or X_query is empty.
    numpy.linalg.LinAlgError
        If the regularised covariance is singular.
    """
    X_ref = np.asarray(X_ref, dtype=np.float64)
    X_query = np.asarray(X_query, dtype=np.float64)
    if X_ref.ndim != 2 or X_query.ndim != 2:
        raise ValueError(
            f"expected 2-D feature matrices, got shapes {X_ref.shape} and {X_query.shape}"
        )
    if X_ref.shape[1] != X_query.shape[1]:
        raise ValueError(
            f"feature dimension mismatch: {X_ref.shape[1]} vs {X_query.shape[1]}"
        )
    if X_ref.shape[0] < 2:
        raise ValueError("need at least two reference samples to estimate a covariance")
    if X_query.shape[0] == 0:
        raise ValueError("X_query contains no samples")

    d = X_ref.shape[1]
    mu_ref = X_ref.mean(axis=0)
    mu_query = X_query.mean(axis=0)
    # atleast_2d: np.cov collapses a single feature to a 0-d array.
    cov = np.atleast_2d(np.cov(X_ref, rowvar=False)) + reg * np.eye(d)
    delta = mu_query - mu_ref
    # Solve Σ x = Δμ rather than forming Σ^{-1} explicitly (more accurate).
    q = float(delta @ np.linalg.solve(cov, delta))
    # np.maximum clamps tiny negative round-off but, unlike builtin
    # max(0.0, q), propagates NaN instead of reporting a perfect D = 0.
    D = float(np.sqrt(np.maximum(q, 0.0)))
    z_scores = delta / np.sqrt(np.diag(cov))
    return D, z_scores
# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def print_feature_table(X_real: np.ndarray, X_gen: np.ndarray) -> None:
    """Print a per-feature comparison of real and generated feature matrices.

    For each column (named by FEATURE_NAMES) this prints the mean and
    population standard deviation (ddof=0) of both sets, plus the z-score
    (μ_gen − μ_real) / σ_real. A 1e-12 guard avoids dividing by zero for
    constant features. Rows with |z| > 1 are flagged with a trailing '<'.

    Raises ValueError (from zip(strict=True)) if the number of feature
    columns does not match FEATURE_NAMES, instead of silently dropping rows.
    """
    mu_r = X_real.mean(axis=0)
    sd_r = X_real.std(axis=0)
    mu_g = X_gen.mean(axis=0)
    sd_g = X_gen.std(axis=0)
    col = 13
    indent = " "
    hdr = (f"{indent}{'Feature':<11} {'Real mean':>{col}} {'Real std':>{col}}"
           f" {'Gen mean':>{col}} {'Gen std':>{col}} {'z-score':>8}")
    print(hdr)
    # Box-drawing rule spanning exactly the header width.
    print(indent + "─" * (len(hdr) - len(indent)))
    for name, mr, sr, mg, sg in zip(FEATURE_NAMES, mu_r, sd_r, mu_g, sd_g, strict=True):
        z = (mg - mr) / (sr + 1e-12)
        flag = " <" if abs(z) > 1.0 else ""
        print(f"{indent}{name:<11} {mr:>{col}.4f} {sr:>{col}.4f}"
              f" {mg:>{col}.4f} {sg:>{col}.4f} {z:>+8.3f}{flag}")
    print()
def print_ensemble_table(stats_real: dict, stats_gen: dict) -> None:
    """Print ensemble observables for real vs generated samples side by side.

    Both dicts must contain the keys "<|m|>", "chi", "C_v" and "U4";
    a missing key raises KeyError. Values are printed with 4 decimals.
    """
    labels = {
        "<|m|>": "mean |m|",
        "chi": "chi (susceptibility)",
        "C_v": "C_v (specific heat)",
        "U4": "U4 (Binder cumulant)",
    }
    header = f" {'Observable':<26} {'Real':>10} {'Generated':>10}"
    print(header)
    # Box-drawing rule sized to the header (the source had mojibake here).
    print(" " + "─" * (len(header) - 1))
    for key, label in labels.items():
        print(f" {label:<26} {stats_real[key]:>10.4f} {stats_gen[key]:>10.4f}")
    print()
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
_SAMPLE_BATCH = 4  # fixed vmapped batch; changing triggers recompilation


def generate_grids(
    model,
    n: int,
    key: jax.Array,
    L: int,
    *,
    batch_size: int = _SAMPLE_BATCH,
) -> np.ndarray:
    """Sample n grids in fixed-size batches with a progress bar.

    Every call to sample_batch uses the same batch size, so only one JIT
    compilation happens regardless of n. The last batch is sampled at full
    size and the surplus rows are discarded before conversion to grids.

    Parameters
    ----------
    model      : generator passed through to sample_batch.
    n          : number of grids to return (must be >= 1).
    key        : JAX PRNG key; split once per batch.
    L          : lattice size passed to tokens_to_grids.
    batch_size : samples per sample_batch call (default _SAMPLE_BATCH).

    Raises ValueError for n < 1 or batch_size < 1. Without this check,
    n = 0 would fail inside np.concatenate with an opaque message.
    """
    if n < 1:
        raise ValueError(f"n must be a positive number of samples, got {n}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    n_batches = -(-n // batch_size)  # ceil(n / batch_size)
    batches = []
    with tqdm(total=n, unit="samples", desc="Sampling") as pbar:
        for i in range(n_batches):
            key, subkey = jax.random.split(key)
            batches.append(np.asarray(sample_batch(model, batch_size, subkey)))
            # Only count the samples that will actually be kept.
            pbar.update(min(batch_size, n - i * batch_size))
    return tokens_to_grids(np.concatenate(batches)[:n], L)
def load_test_grids(
    test_data: Path | None,
    data: Path,
    n: int,
    L: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Load real test grids, preferring a dedicated test file over the val split.

    Parameters
    ----------
    test_data : optional path to a standalone test .npy file (N, L, L) of
                {-1,+1} spins
    data      : path to the main spins.npy (used only if test_data is None;
                the second value returned by load_ising_data is used)
    n         : number of grids to draw; capped at the available count
    L         : expected lattice size
    rng       : generator used to pick a random subset without replacement

    Returns
    -------
    (min(n, available), L, L) array of ±1 grids, as produced by tokens_to_grids.

    Raises
    ------
    ValueError if the grids in *test_data* are not of shape (L, L).
    """
    if test_data is not None:
        spins = np.load(test_data)  # expected (N, L, L) int8
        if spins.ndim != 3 or spins.shape[1:] != (L, L):
            raise ValueError(
                f"test-data grid shape {spins.shape[1:]} != ({L},{L})"
            )
        tokens = (spins.astype(np.int32) + 1) // 2  # {-1, +1} → {0, 1}
        rows, cols = snake_order(L)
        tokens = tokens[:, rows, cols]  # (N, L²) in snake order
    else:
        _, tokens = load_ising_data(data)  # val split of spins.npy
    n = min(n, len(tokens))
    idx = rng.choice(len(tokens), size=n, replace=False)
    return tokens_to_grids(tokens[idx], L)  # (n, L, L), values ±1
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the evaluation script.

    --checkpoint is needed only when samples must be generated, so it is
    required unless --samples-file is given. --num-samples must be at
    least 2 because the Mahalanobis distance needs a sample covariance of
    the real features. Violations exit via ArgumentParser.error (status 2).
    """
    p = argparse.ArgumentParser(
        description="Compare generated vs real Ising samples via physical observables."
    )
    p.add_argument("--checkpoint", type=Path, default=None,
                   help="Path to the .eqx checkpoint file. "
                        "Required unless --samples-file is given.")
    p.add_argument("--data", type=Path,
                   default=Path(__file__).parent / "spins.npy",
                   help="Path to spins.npy (default: ./spins.npy). "
                        "Used only if --test-data is not provided.")
    p.add_argument("--test-data", type=Path,
                   default=Path(__file__).parent / "spins_test.npy",
                   help="Dedicated held-out test set (.npy, N×L×L int8 {-1,+1}). "
                        "Takes priority over the val split of --data.")
    p.add_argument("--num-samples", type=int, default=50,
                   help="Number of samples to compare (default: 50).")
    p.add_argument("--samples-file", type=Path, default=None,
                   help="Optional .npy of pre-generated {-1,+1} grids (N,L,L) "
                        "from 'sample.py --output'. Skips generation entirely.")
    p.add_argument("--seed", type=int, default=0)
    args = p.parse_args()
    if args.checkpoint is None and args.samples_file is None:
        p.error("--checkpoint is required unless --samples-file is given")
    if args.num_samples < 2:
        p.error("--num-samples must be at least 2")
    return args
def _load_samples_file(path: Path, n: int, L: int) -> np.ndarray:
    """Load at most *n* pre-generated spin grids from *path* as int8 ±1 values.

    Raises ValueError if the grids are not (L, L) or contain values other
    than -1/+1 (e.g. a 0/1 token file, which would silently skew every
    feature).
    """
    raw = np.load(path)[:n]
    if raw.shape[1:] != (L, L):
        raise ValueError(
            f"samples-file grid shape {raw.shape[1:]} != ({L},{L})"
        )
    if not np.isin(raw, (-1, 1)).all():
        raise ValueError(f"samples-file {path} must contain only -1/+1 spins")
    return raw.astype(np.int8)


def main():
    """Compare generated samples with held-out real samples and print a report.

    Loads up to --num-samples real test grids and the same number of
    generated grids (from --samples-file, or by sampling the checkpoint).
    It then prints per-feature statistics, ensemble statistics and the
    Mahalanobis distance between the two feature sets.
    """
    args = parse_args()
    L = gen_config["lattice_size"]
    rng = np.random.default_rng(args.seed)

    # ── Real samples (test split) ─────────────────────────────────────────
    # Prefer spins_test.npy; fall back to the val split of spins.npy.
    test_path = args.test_data if (args.test_data and args.test_data.exists()) else None
    if test_path:
        print(f"Loading test data from {test_path} …")
    else:
        print("Loading test data from val split of spins.npy …")
    n = args.num_samples
    real_grids = load_test_grids(test_path, args.data, n, L, rng)
    n = len(real_grids)  # may be capped by dataset size
    # The sample covariance used by the Mahalanobis distance is undefined
    # (NaN) for fewer than two samples; fail before any expensive sampling.
    if n < 2:
        raise SystemExit(f"Need at least 2 real test samples, found {n}.")

    # ── Generated samples ─────────────────────────────────────────────────
    if args.samples_file is not None:
        print(f"Loading pre-generated samples from {args.samples_file} …")
        gen_grids = _load_samples_file(args.samples_file, n, L)
        n = min(n, len(gen_grids))
        real_grids = real_grids[:n]  # keep both groups the same size
        if n < 2:
            raise SystemExit(f"Need at least 2 generated samples, found {n}.")
    else:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
        key = jax.random.PRNGKey(args.seed)
        gen_grids = generate_grids(model, n, key, L)  # (n, L, L), values ±1
    print(f"\nL = {L} | N = {n} samples per group | T_C = {T_C:.6f}\n")

    # ── Feature matrices ──────────────────────────────────────────────────
    X_real = compute_feature_matrix(real_grids, desc="Features: real ")
    X_gen = compute_feature_matrix(gen_grids, desc="Features: generated ")

    # ── Per-feature comparison table ──────────────────────────────────────
    print("Per-feature statistics (z-score = Δμ / σ_real; '<' marks |z| > 1)\n")
    print_feature_table(X_real, X_gen)

    # ── Ensemble statistics ───────────────────────────────────────────────
    print("Ensemble statistics\n")
    print_ensemble_table(ensemble_stats(X_real), ensemble_stats(X_gen))

    # ── Mahalanobis distance ──────────────────────────────────────────────
    D, z = mahalanobis_distance(X_real, X_gen)
    print(f"Mahalanobis distance D = {D:.4f}")
    print(" (D measures how many 'std-devs' the generated feature mean sits")
    print(" from the real distribution in the decorrelated feature space.)")
    print()
    print(" Top deviating features:")
    order = np.argsort(np.abs(z))[::-1]  # largest |z| first
    for i in order[:5]:
        print(f" {FEATURE_NAMES[i]:<11} z = {z[i]:+.3f}")


if __name__ == "__main__":
    main()