# AbstractPhil's picture
# Rename 8_probe_ft3.py to 008_probe_ft3.py
# c62978d verified
"""
cell_g_class_probe_v3.py — three-way geometric probe
Runs the same battery of geometric metrics on three batteries:
H2a: Q-rank02 (D=4, V=32, 40K params, 1000 batches Adam)
G-Cand: Q-rank09 (D=3, V=32, 29K params, 1000 batches Adam)
h2-64: single-noise gaussian battery (D=8, V=64, 57K params, 10 epochs)
Key question: is the antipodal+rotational structure found in G-Cand a
property of D=3 specifically, or a property of LOW-band attractors at
ANY D? h2-64 has D=8 which sits in LOW band naturally (CV ~0.21).
Predicted outcomes:
- h2-64 looks like H2 (uniform sphere, stable rows): G-class is D=3-specific
- h2-64 looks like G (antipodal pairs, rotating frame): G-class is the
universal LOW-band character; H2a is the OUTLIER for being so static
- h2-64 looks like neither (some third pattern): D=8 has its own
geometric character we haven't seen yet
Loads h2-64 from `loaded` if it is defined in the session; otherwise fetches it from HF.
"""
import json
import math
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
# Root directory of the phase-Q sweep outputs; every input and output path lives under it.
CKPT_DIR = Path("/content/phaseQ_reports")
# Epoch-1 checkpoints of the two Q-sweep runs under comparison.
RANK09_CKPT = CKPT_DIR.joinpath("Q_rank09_h64_V32_D3_dp0_nx0_adam", "epoch_1_checkpoint.pt")
RANK02_CKPT = CKPT_DIR.joinpath("Q_rank02_h64_V32_D4_dp0_nx0_adam", "epoch_1_checkpoint.pt")
# Artifacts written by main(): comparison figure and JSON results.
OUTPUT_PLOT = CKPT_DIR.joinpath("g_class_probe_v3.png")
OUTPUT_JSON = CKPT_DIR.joinpath("g_class_probe_v3.json")
# ════════════════════════════════════════════════════════════════════
# Loading
# ════════════════════════════════════════════════════════════════════
def load_qsweep_model(variant_str, ckpt_path):
    """Load a Q-sweep model (e.g. rank02 or rank09) from a checkpoint.

    Args:
        variant_str: Substring searched for in each Q-sweep config's
            ``'variant'`` field; the first matching config is used.
        ckpt_path: Path of the checkpoint file to load.

    Returns:
        ``(model, cfg)``: the ``PatchSVAE_F_Ablation`` model with its weights
        loaded and switched to eval mode, and the config returned by
        ``build_run_config``.

    Raises:
        ValueError: If no Q-sweep config has a variant containing
            ``variant_str``.

    NOTE(review): ``get_phaseQ_configs``, ``build_run_config`` and
    ``PatchSVAE_F_Ablation`` are not imported in this file; they must already
    be defined in the running session.
    """
    cfgs = list(get_phaseQ_configs())
    cfg_dict = next((c for c in cfgs if variant_str in c['variant']), None)
    if cfg_dict is None:
        # A bare next() would surface as an opaque StopIteration.
        known = [c['variant'] for c in cfgs]
        raise ValueError(
            f"No Q-sweep config whose variant contains {variant_str!r}; "
            f"available variants: {known}")
    cfg = build_run_config(cfg_dict)
    # A config without an 'overrides' entry falls back to the defaults below.
    overrides = cfg_dict.get('overrides') or {}
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
        hidden=cfg.hidden, depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    )
    # SECURITY: weights_only=False unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    # Trainers nest the weights under different keys; a bare state_dict is
    # used as-is.
    state_dict = (
        ckpt.get('model_state')
        or ckpt.get('model_state_dict')
        or ckpt.get('state_dict')
        or ckpt
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
def load_h2_64_battery(battery_idx=0, phase='final'):
    """Return one battery bank from the h2-64 battery array plus its geometry.

    The array object is taken from a module-global named ``loaded`` when one
    exists (e.g. already present in the Colab session). Otherwise it is
    downloaded with ``AutoModel.from_pretrained("AbstractPhil/geolip-svae-h2-64")``.

    Args:
        battery_idx: Index of the battery inside the array.
        phase: Phase passed through to ``array_model.bank``.

    Returns:
        ``(bank_module, V, D, patch_size, img_size)``. The bank is put in eval
        mode. The four geometry values are fixed constants (64, 8, 2, 64) and
        are not read from the battery config.
    """
    array_model = globals().get('loaded')
    if array_model is not None:
        print(" Using `loaded` from global session")
    else:
        print(" `loaded` not found, fetching from HF...")
        # Importing geolip_svae.arrays registers BatteryArrayConfig with the
        # HF Auto* classes; without it model_type='battery_array' is unknown.
        import geolip_svae.arrays  # noqa: F401
        from transformers import AutoModel
        array_model = AutoModel.from_pretrained("AbstractPhil/geolip-svae-h2-64")
        print(" Loaded h2-64 from HF")

    bank = array_model.bank(battery_idx, phase)
    bank.eval()

    battery_cfg = array_model.config.batteries[battery_idx]
    subgroup, variant, noise_types = (
        battery_cfg.get(key) for key in ('subgroup', 'variant', 'noise_types'))
    print(f" Battery {battery_idx} ({phase}): "
          f"subgroup={subgroup}, "
          f"variant={variant}, "
          f"noise_types={noise_types}")

    # All h2-64 batteries share one architecture: V=64, D=8, patch 2, 64px images.
    return bank, 64, 8, 2, 64
# ════════════════════════════════════════════════════════════════════
# Collect M rows
# ════════════════════════════════════════════════════════════════════
def collect_per_sample_M(model, V, D, patch_size, img_size,
                         n_batches=8, batch_size=64,
                         is_h2_64_bank=False):
    """Feed gaussian-noise images through ``model`` and gather patch-0 M matrices.

    ``n_batches * batch_size`` images of size ``img_size`` are drawn from
    ``OmegaNoiseDataset`` restricted to noise type 0 (gaussian) and run
    through ``model`` on CUDA when available, else CPU. Note that the model is
    moved to that device. For every image the M tensor of the first patch is
    kept.

    The M tensor is read from ``out['svd']['M']`` or ``out['M']``. If neither
    is present, it is recomputed via ``model.encode_patches`` on patches from
    ``johanna_F_trainer.extract_patches``.

    ``V``, ``D`` and ``is_h2_64_bank`` are accepted for call-site symmetry but
    are not used.

    Returns:
        numpy array concatenated over all images; expected shape is
        [n_samples, V, D] (assumed from the model's M layout -- not checked).
    """
    dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(dev)
    gaussian_only = OmegaNoiseDataset(
        size=n_batches * batch_size,
        img_size=img_size,
        allowed_types=[0])  # type 0 = gaussian
    batches = torch.utils.data.DataLoader(
        gaussian_only, batch_size=batch_size, shuffle=False)

    chunks = []
    with torch.no_grad():
        for images, _labels in batches:
            images = images.to(dev)
            outputs = model(images)
            if 'svd' in outputs and 'M' in outputs['svd']:
                m_full = outputs['svd']['M']
            elif 'M' in outputs:
                m_full = outputs['M']
            else:
                # Neither output layout carries M: re-encode the patches directly.
                from johanna_F_trainer import extract_patches
                m_full = model.encode_patches(extract_patches(images, patch_size))['M']
            chunks.append(m_full[:, 0].cpu())  # keep patch 0 only
    return torch.cat(chunks, dim=0).numpy()
# ════════════════════════════════════════════════════════════════════
# Tests (carry over from v2)
# ════════════════════════════════════════════════════════════════════
def test_sphere_norm(all_M):
row_norms = np.linalg.norm(all_M, axis=2)
return {
'min': float(row_norms.min()),
'max': float(row_norms.max()),
'mean': float(row_norms.mean()),
'std': float(row_norms.std()),
'sphere_normed': bool(
abs(row_norms.mean() - 1.0) < 0.05 and row_norms.std() < 0.05),
}
def test_row_stability(all_M):
mean_dirs = all_M.mean(axis=0)
mean_dir_norms = np.linalg.norm(mean_dirs, axis=1)
return {
'mean': float(mean_dir_norms.mean()),
'min': float(mean_dir_norms.min()),
'max': float(mean_dir_norms.max()),
'std': float(mean_dir_norms.std()),
'mean_dir_norms': mean_dir_norms.tolist(),
}
def test_per_sample_clustering(all_M, k_test=5, n_samples=20):
    """Score how strongly the rows of individual samples form clusters.

    For each of the first ``n_samples`` samples, the sample's M rows are
    clustered with k-means (``k_test`` clusters, ``n_init=10``, seed 42). The
    silhouette score of that clustering is then recorded.

    A sample is skipped rather than scored when:
    - k-means rejects it (e.g. fewer rows than ``k_test``), or
    - the clustering has a label count outside what ``silhouette_score``
      accepts (2 <= n_labels <= n_rows - 1).

    Returns:
        dict with:
        - 'k_tested';
        - 'mean' and 'std' of the recorded silhouettes (None when no sample
          could be scored);
        - 'silhouettes_per_sample';
        - 'n_skipped', the number of samples that produced no score.
    """
    silhouettes = []
    n_skipped = 0
    for i in range(min(n_samples, all_M.shape[0])):
        M = all_M[i]
        try:
            km = KMeans(n_clusters=k_test, n_init=10, random_state=42)
            labels = km.fit_predict(M)
        except ValueError:
            # Best-effort: an unclusterable sample is counted, not fatal.
            n_skipped += 1
            continue
        n_labels = len(set(labels))
        # silhouette_score is only defined for 2 <= n_labels <= n_rows - 1;
        # outside that range it always raises, so skip explicitly.
        if 2 <= n_labels <= len(M) - 1:
            silhouettes.append(silhouette_score(M, labels))
        else:
            n_skipped += 1
    silhouettes = np.array(silhouettes)
    has_scores = len(silhouettes) > 0
    return {
        'k_tested': k_test,
        'mean': float(silhouettes.mean()) if has_scores else None,
        'std': float(silhouettes.std()) if has_scores else None,
        'silhouettes_per_sample': silhouettes.tolist(),
        'n_skipped': n_skipped,
    }
def test_angular_distribution(all_M):
all_rows = all_M.reshape(-1, all_M.shape[-1])
norms = np.linalg.norm(all_rows, axis=1, keepdims=True)
unit_rows = all_rows / np.clip(norms, 1e-12, None)
n_subset = min(500, unit_rows.shape[0])
idx = np.random.RandomState(42).choice(
unit_rows.shape[0], n_subset, replace=False)
subset = unit_rows[idx]
cosines = subset @ subset.T
triu_idx = np.triu_indices(n_subset, k=1)
pairwise_cos = cosines[triu_idx]
pairwise_angles = np.arccos(np.clip(pairwise_cos, -1, 1))
return {
'mean_angle': float(pairwise_angles.mean()),
'median_angle': float(np.median(pairwise_angles)),
'fraction_near_zero': float((pairwise_angles < 0.5).mean()),
'fraction_near_pi': float((pairwise_angles > math.pi - 0.5).mean()),
'fraction_near_perp': float(
((pairwise_angles > math.pi/2 - 0.3) &
(pairwise_angles < math.pi/2 + 0.3)).mean()),
'pairwise_angles_subset': pairwise_angles[:200].tolist(),
}
def test_antipodal(all_M):
mean_dirs = all_M.mean(axis=0)
norms = np.linalg.norm(mean_dirs, axis=1, keepdims=True)
unit_dirs = mean_dirs / np.clip(norms, 1e-12, None)
cosines = unit_dirs @ unit_dirs.T
np.fill_diagonal(cosines, 1.0)
most_anti_cos = cosines.min(axis=1)
n_pairs = (most_anti_cos < -0.9).sum() // 2
return {
'min_cos': float(most_anti_cos.min()),
'mean_cos': float(most_anti_cos.mean()),
'fraction_with_antipode': float((most_anti_cos < -0.9).mean()),
'estimated_pairs': int(n_pairs),
'max_possible_pairs': all_M.shape[1] // 2,
}
def test_effective_rank(all_M):
M_avg = all_M.mean(axis=0)
sv = np.linalg.svd(M_avg, compute_uv=False)
sv_norm = sv / sv.sum()
erank = math.exp(-(sv_norm * np.log(sv_norm + 1e-12)).sum())
return {
'singular_values': sv.tolist(),
'normalized_SV': sv_norm.tolist(),
'effective_rank': float(erank),
'D': int(all_M.shape[2]),
'utilization': float(erank / all_M.shape[2]),
'top1_share': float(sv_norm[0]),
}
def run_all_tests(all_M, label):
    """Run every geometric test on one set of M rows and print a summary.

    Args:
        all_M: array of M matrices indexed [sample, row, dim], as collected by
            collect_per_sample_M.
        label: heading printed above this battery's summary.

    Returns:
        dict with keys 'sphere_norm', 'stability', 'clustering', 'angular',
        'antipodal' and 'rank', each holding the corresponding test's result.
    """
    print(f"\n[{label}]")
    print(f" Shape: {all_M.shape}")
    sphere = test_sphere_norm(all_M)
    print(f" Sphere-norm: mean={sphere['mean']:.4f}, "
          f"std={sphere['std']:.4f} → {'YES' if sphere['sphere_normed'] else 'NO'}")
    stability = test_row_stability(all_M)
    print(f" Row stability: mean={stability['mean']:.3f}, "
          f"range=[{stability['min']:.3f}, {stability['max']:.3f}]")
    cluster = test_per_sample_clustering(all_M)
    # Report the k actually used rather than a hard-coded value.
    if cluster['mean'] is not None:
        print(f" Cluster (k={cluster['k_tested']}): silhouette mean={cluster['mean']:.3f}, "
              f"std={cluster['std']:.3f}")
    else:
        print(f" Cluster (k={cluster['k_tested']}): no sample could be scored")
    angular = test_angular_distribution(all_M)
    print(f" Angular: mean={angular['mean_angle']:.3f} "
          f"(uniform=π/2={math.pi/2:.3f})")
    print(f" near-perp: {angular['fraction_near_perp']:.3f}, "
          f"near-π: {angular['fraction_near_pi']:.3f}")
    antipodal = test_antipodal(all_M)
    print(f" Antipodal: {antipodal['estimated_pairs']}/"
          f"{antipodal['max_possible_pairs']} pairs, "
          f"frac with antipode={antipodal['fraction_with_antipode']:.3f}")
    erank = test_effective_rank(all_M)
    print(f" Effective rank: {erank['effective_rank']:.2f} of {erank['D']} "
          f"({erank['utilization']*100:.0f}% utilization)")
    return {
        'sphere_norm': sphere,
        'stability': stability,
        'clustering': cluster,
        'angular': angular,
        'antipodal': antipodal,
        'rank': erank,
    }
# ════════════════════════════════════════════════════════════════════
# Composite character classification
# ════════════════════════════════════════════════════════════════════
def classify_battery_character(results):
    """Label a battery as H2-like, G-like, diffuse or hybrid.

    Reads row stability (results['stability']['mean']), antipodal fraction
    (results['antipodal']['fraction_with_antipode']) and rank utilization
    (results['rank']['utilization']) from a run_all_tests() result.

    Rules, checked in order:
      - H2-LIKE: stability > 0.85, antipodal fraction < 0.55, utilization > 0.95
      - G-LIKE: stability < 0.65 and antipodal fraction > 0.80
      - DIFFUSE: stability < 0.65 and antipodal fraction < 0.55
      - HYBRID: anything else (the label embeds both metrics)

    Returns:
        A human-readable label string.
    """
    stab = results['stability']['mean']
    antipodal_frac = results['antipodal']['fraction_with_antipode']
    rank_util = results['rank']['utilization']
    # H2-like: static rows, few antipodes, full rank.
    if stab > 0.85 and antipodal_frac < 0.55 and rank_util > 0.95:
        return "H2-LIKE (static sphere-solver)"
    # G-like: rows drift across samples but sit in antipodal pairs.
    if stab < 0.65 and antipodal_frac > 0.80:
        return "G-LIKE (rotating antipodal frame)"
    if stab < 0.65 and antipodal_frac < 0.55:
        return "DIFFUSE (low stability, no antipodal structure)"
    return (f"HYBRID (stab={stab:.2f}, antipodal_frac="
            f"{antipodal_frac:.2f})")
# ════════════════════════════════════════════════════════════════════
# Main
# ════════════════════════════════════════════════════════════════════
def main():
    """Load the three batteries, run the geometric tests and report.

    Steps:
      1. Load H2a (Q-rank02) and G-Cand (Q-rank09) from the local checkpoints,
         and battery 0 ('final' phase) of h2-64.
      2. Collect patch-0 M rows from gaussian inputs for each.
      3. Run run_all_tests on each.
      4. Print a side-by-side table, a character verdict per battery and a
         headline conclusion.
      5. Write OUTPUT_JSON and OUTPUT_PLOT.

    Returns:
        dict of per-battery results plus their character labels (the same
        object written to OUTPUT_JSON).
    """
    print("=" * 70)
    print("Loading three batteries for comparative analysis")
    print("=" * 70)
    print("\n[1/3] H2a (Q-rank02, D=4, 1000-batch Adam)")
    h2_model, h2_cfg = load_qsweep_model('rank02', RANK02_CKPT)
    print(f" V={h2_cfg.matrix_v}, D={h2_cfg.D}, "
          f"params={sum(p.numel() for p in h2_model.parameters()):,}")
    print("\n[2/3] G-Class candidate (Q-rank09, D=3, 1000-batch Adam)")
    g_model, g_cfg = load_qsweep_model('rank09', RANK09_CKPT)
    print(f" V={g_cfg.matrix_v}, D={g_cfg.D}, "
          f"params={sum(p.numel() for p in g_model.parameters()):,}")
    print("\n[3/3] h2-64 single-noise gaussian battery (D=8, 10 epochs converged)")
    h264_bank, h264_V, h264_D, h264_ps, h264_img = load_h2_64_battery(
        battery_idx=0, phase='final')
    print(f" V={h264_V}, D={h264_D}, patch_size={h264_ps}, img_size={h264_img}")

    # ── Collect M rows ──────────────────────────────────────────────
    print("\n" + "=" * 70)
    print("Collecting M rows (gaussian inputs, 512 samples each)")
    print("=" * 70)
    print("\n H2a...")
    all_M_h2 = collect_per_sample_M(
        h2_model, h2_cfg.matrix_v, h2_cfg.D,
        h2_cfg.patch_size, h2_cfg.img_size)
    print(" G-Cand...")
    all_M_g = collect_per_sample_M(
        g_model, g_cfg.matrix_v, g_cfg.D,
        g_cfg.patch_size, g_cfg.img_size)
    print(" h2-64 gaussian...")
    all_M_h264 = collect_per_sample_M(
        h264_bank, h264_V, h264_D, h264_ps, h264_img,
        is_h2_64_bank=True)

    # ── Run tests on each ───────────────────────────────────────────
    print("\n" + "=" * 70)
    print("GEOMETRIC ANALYSIS")
    print("=" * 70)
    results_h2 = run_all_tests(all_M_h2, "H2a (D=4, 1000-batch Adam)")
    results_g = run_all_tests(all_M_g, "G-Cand (D=3, 1000-batch Adam)")
    results_h264 = run_all_tests(
        all_M_h264, "h2-64 gaussian (D=8, 10 epochs)")

    # ── Side-by-side comparison ─────────────────────────────────────
    print("\n" + "=" * 70)
    print("THREE-WAY COMPARISON")
    print("=" * 70)
    headers = f"{'Metric':<32} {'H2a (D=4)':>12} {'G-Cand (D=3)':>14} {'h2-64 (D=8)':>14}"
    print(f"\n {headers}")
    print(" " + "-" * len(headers))
    rows = [
        ('Effective rank',
         results_h2['rank']['effective_rank'],
         results_g['rank']['effective_rank'],
         results_h264['rank']['effective_rank'],
         '.2f'),
        ('Dim utilization (%)',
         results_h2['rank']['utilization'] * 100,
         results_g['rank']['utilization'] * 100,
         results_h264['rank']['utilization'] * 100,
         '.0f'),
        ('Row stability',
         results_h2['stability']['mean'],
         results_g['stability']['mean'],
         results_h264['stability']['mean'],
         '.3f'),
        # A battery with no scorable sample is shown as 0.
        ('Per-sample silhouette (k=5)',
         results_h2['clustering']['mean'] or 0,
         results_g['clustering']['mean'] or 0,
         results_h264['clustering']['mean'] or 0,
         '.3f'),
        ('Mean pairwise angle (rad)',
         results_h2['angular']['mean_angle'],
         results_g['angular']['mean_angle'],
         results_h264['angular']['mean_angle'],
         '.3f'),
        ('Antipodal pair fraction',
         results_h2['antipodal']['fraction_with_antipode'],
         results_g['antipodal']['fraction_with_antipode'],
         results_h264['antipodal']['fraction_with_antipode'],
         '.3f'),
        ('Estimated antipodal pairs',
         results_h2['antipodal']['estimated_pairs'],
         results_g['antipodal']['estimated_pairs'],
         results_h264['antipodal']['estimated_pairs'],
         'd'),
    ]
    # The nested format spec handles both float ('.3f') and int ('d') rows.
    for name, h2v, gv, h264v, fmt in rows:
        print(f" {name:<32} {h2v:>12{fmt}} {gv:>14{fmt}} {h264v:>14{fmt}}")
    print()
    char_h2 = classify_battery_character(results_h2)
    char_g = classify_battery_character(results_g)
    char_h264 = classify_battery_character(results_h264)
    print(" Character verdict:")
    print(f" H2a: {char_h2}")
    print(f" G-Cand: {char_g}")
    print(f" h2-64: {char_h264}")

    # ── Headline conclusion ─────────────────────────────────────────
    print("\n" + "=" * 70)
    print("CONCLUSION")
    print("=" * 70)
    if "G-LIKE" in char_h264:
        print(" h2-64 (D=8, fully converged) shows G-CLASS character.")
        print(" → The antipodal+rotational structure is NOT D=3-specific.")
        print(" → It's the LOW-band attractor's natural geometry.")
        print(" → H2a (D=4 at HIGH band) is the OUTLIER — its sphere-solver")
        print(" rigidity is HIGH-band-specific, not the universal pattern.")
    elif "H2-LIKE" in char_h264:
        print(" h2-64 (D=8, fully converged) shows H2 sphere-solver character.")
        print(" → G-Class at D=3 is genuinely different from sphere-solvers.")
        print(" → D=3 specifically can't form a stable static 32-row arrangement,")
        print(" so it falls into the rotating-antipodal regime.")
        print(" → Higher D recovers static sphere-solver behavior even in LOW band.")
    elif "HYBRID" in char_h264 or "DIFFUSE" in char_h264:
        print(" h2-64 (D=8) shows mixed character — partial G-like features.")
        print(" → Possible spectrum: HIGH-band → static sphere (H2),")
        print(" LOW-band → progressively more antipodal as D decreases.")
        print(" → D=8 sits in transition; D=3 is fully G-class; D=4 HIGH is fully H2.")

    all_results = {
        'h2a': results_h2,
        'g_class_candidate': results_g,
        'h2_64_gaussian': results_h264,
        'characters': {
            'h2a': char_h2,
            'g_class': char_g,
            'h2_64': char_h264,
        },
    }
    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"\n Saved: {OUTPUT_JSON}")
    plot_three_way(all_M_h2, all_M_g, all_M_h264,
                   results_h2, results_g, results_h264, OUTPUT_PLOT)
    print(f" Saved: {OUTPUT_PLOT}")
    return all_results
def plot_three_way(M_h2, M_g, M_h264, r_h2, r_g, r_h264, output_path):
    """Save and show a 9-panel (3x3) comparison figure: 3 batteries x 3 views.

    Columns are H2a, G-Cand and h2-64. Rows are:
      1. 3-D scatter of the first sample's M rows (first three coordinates),
         colored by row index;
      2. per-row stability (mean-direction norms, sorted descending);
      3. histogram of the saved pairwise-angle subset, with pi/2 marked.

    Args:
        M_h2, M_g, M_h264: per-battery M arrays indexed [sample, row, dim].
        r_h2, r_g, r_h264: the matching run_all_tests() results.
        output_path: file the figure is written to (dpi=120).
    """
    panels = [
        (M_h2, r_h2, 'H2a', 'blue',
         'H2a (D=4) — single sample\nrows projected to first 3 dims'),
        (M_g, r_g, 'G-Cand', 'red',
         'G-Cand (D=3) — single sample\nfull native dims'),
        (M_h264, r_h264, 'h2-64', 'green',
         'h2-64 gaussian (D=8) — single sample\nrows projected to first 3 dims'),
    ]
    fig = plt.figure(figsize=(18, 14))
    for col, (all_M, res, short, color, scatter_title) in enumerate(panels, start=1):
        # Row 1: one sample's rows, projected onto the first three dims.
        sample = all_M[0]
        ax = fig.add_subplot(3, 3, col, projection='3d')
        ax.scatter(sample[:, 0], sample[:, 1], sample[:, 2],
                   c=np.arange(len(sample)), cmap='viridis', s=80,
                   edgecolors='black', linewidths=0.5)
        ax.set_title(scatter_title)

        # Row 2: per-row stability, sorted descending.
        ax = fig.add_subplot(3, 3, 3 + col)
        ax.plot(sorted(res['stability']['mean_dir_norms'], reverse=True),
                'o-', color=color, markersize=4)
        ax.set_title(f"{short} row stability\nmean={res['stability']['mean']:.3f}")
        ax.set_xlabel('Row index (sorted)')
        ax.set_ylabel('Mean direction norm')
        ax.set_ylim([0, 1.05])
        ax.grid(alpha=0.3)

        # Row 3: pairwise angle distribution; dashed line at pi/2.
        ax = fig.add_subplot(3, 3, 6 + col)
        ax.hist(res['angular']['pairwise_angles_subset'], bins=30,
                color=color, alpha=0.7, density=True)
        ax.axvline(math.pi/2, color='black', linestyle='--', alpha=0.5)
        ax.set_title(f"{short} pairwise angles\nmean={res['angular']['mean_angle']:.3f}")
        ax.set_xlabel('Angle (radians)')

    fig.tight_layout()
    fig.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()
    # Release the figure; non-interactive backends otherwise keep it alive.
    plt.close(fig)
if __name__ == '__main__':
    # Script entry point: run the full three-way probe. The return value is
    # bound to the module-level name `results`.
    results = main()