| """ |
cell_p_class_probe_v2.py — deeper geometric probe for P-Class
| |
| Addresses limitations of v1's averaged-M analysis: |
| 1. Verify sphere-norm is enforced per-sample (M rows should be unit-length |
| per-sample, even if they average to sub-unit across samples) |
| 2. Test structure on PER-SAMPLE M, not averaged |
| 3. Check if the 5-cluster finding from v1 is consistent or sample-dependent |
4. Spherical structure analysis: project rows to S², test for angular
   distribution structure (uniform? clustered? band-like?)
| 5. Reconstruct what the H2 sphere-solver looks like for comparison |
| |
Key question: are the 32 rows really clustered, or does each sample have
its own spread of 32 rows on S² that AVERAGE to look clustered?
| """ |
|
|
| import json |
| import math |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| from mpl_toolkits.mplot3d import Axes3D |
| from sklearn.cluster import KMeans |
| from sklearn.metrics import silhouette_score |
|
|
|
|
# Root folder holding the Phase-Q run directories; this probe also writes its
# outputs here.
CKPT_DIR = Path("/content/phaseQ_reports")

# Epoch-1 checkpoints of the two runs being compared.
RANK09_CKPT = CKPT_DIR.joinpath("Q_rank09_h64_V32_D3_dp0_nx0_adam", "epoch_1_checkpoint.pt")
RANK02_CKPT = CKPT_DIR.joinpath("Q_rank02_h64_V32_D4_dp0_nx0_adam", "epoch_1_checkpoint.pt")

# Artifacts produced by main(): the diagnostic figure and the raw results.
OUTPUT_PLOT = CKPT_DIR.joinpath("p_rank09_probe_v2.png")
OUTPUT_JSON = CKPT_DIR.joinpath("p_rank09_probe_v2.json")
|
|
|
|
def load_model(variant_str, ckpt_path):
    """Build the Phase-Q model whose config variant contains ``variant_str`` and load its weights.

    Args:
        variant_str: substring matched against each Phase-Q config's ``'variant'``
            name; the first matching config is used.
        ckpt_path: checkpoint file readable by ``torch.load``.

    Returns:
        ``(model, cfg)``: the ``PatchSVAE_F_Ablation`` model in eval mode with the
        checkpoint weights loaded, and the run config built from the matched entry.

    Raises:
        ValueError: if no Phase-Q config variant contains ``variant_str``.
    """
    # Materialised because it is iterated a second time for the error message.
    cfgs = list(get_phaseQ_configs())
    cfg_dict = next((c for c in cfgs if variant_str in c['variant']), None)
    if cfg_dict is None:
        known = ', '.join(str(c['variant']) for c in cfgs)
        raise ValueError(
            f"no Phase-Q config variant contains {variant_str!r} (known: {known})")
    cfg = build_run_config(cfg_dict)
    overrides = cfg_dict['overrides']

    # Architecture switches come from the run's overrides; the second argument of
    # each .get() is the value used when the run does not override that key.
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
        hidden=cfg.hidden, depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    )

    # NOTE(security): weights_only=False uses the full pickle loader, which can
    # execute arbitrary code embedded in the file. Only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)

    # The weights may sit under one of several wrapper keys, or the checkpoint may
    # itself be the state dict. The first non-empty wrapper entry wins.
    state_dict = ckpt
    for key in ('model_state', 'model_state_dict', 'state_dict'):
        candidate = ckpt.get(key)
        if candidate:
            state_dict = candidate
            break

    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
|
|
|
|
def collect_per_sample_M(model, cfg, n_batches=8, batch_size=64):
    """Run ``model`` over an OmegaNoiseDataset and return patch-0 M for every sample.

    Unlike v1, nothing is averaged: each sample keeps its own M.

    Args:
        model: network whose forward output exposes ``out['svd']['M']``. It is moved
            to CUDA when available, otherwise to CPU.
        cfg: run config; only ``cfg.img_size`` is read here.
        n_batches: number of batches to draw.
        batch_size: samples per batch; the dataset size is ``n_batches * batch_size``.

    Returns:
        numpy array of ``out['svd']['M'][:, 0]`` for every batch, concatenated along
        the first (sample) axis.
    """
    dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(dev)

    # allowed_types=[0] presumably restricts the generator to one noise type; confirm.
    dataset = OmegaNoiseDataset(
        size=n_batches * batch_size,
        img_size=cfg.img_size,
        allowed_types=[0])
    batches = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    chunks = []
    with torch.no_grad():
        for images, _labels in batches:
            svd_out = model(images.to(dev))['svd']
            # Keep only patch index 0 of each sample's M.
            chunks.append(svd_out['M'][:, 0].cpu())

    return torch.cat(chunks, dim=0).numpy()
|
|
|
|
| |
| |
| |
|
|
| def test_sphere_norm(all_M, label): |
| """Verify that per-sample rows are unit-length (sphere-normed).""" |
| print(f"\n[{label}] PER-SAMPLE sphere-norm verification:") |
|
|
| |
| row_norms = np.linalg.norm(all_M, axis=2) |
|
|
| print(f" Per-sample row norms:") |
| print(f" overall min: {row_norms.min():.4f}") |
| print(f" overall max: {row_norms.max():.4f}") |
| print(f" overall mean: {row_norms.mean():.4f}") |
| print(f" overall std: {row_norms.std():.4f}") |
|
|
| is_normed = ( |
| abs(row_norms.mean() - 1.0) < 0.05 and |
| row_norms.std() < 0.05 |
| ) |
| print(f" Sphere-norm enforced per-sample: {is_normed}") |
|
|
| return { |
| 'row_norms_min': float(row_norms.min()), |
| 'row_norms_max': float(row_norms.max()), |
| 'row_norms_mean': float(row_norms.mean()), |
| 'row_norms_std': float(row_norms.std()), |
| 'sphere_normed_per_sample': bool(is_normed), |
| } |
|
|
|
|
| |
| |
| |
|
|
| def test_row_stability(all_M, label): |
| """For each row index i in [0, V), how much does row i vary across |
| samples? If rows are stable (each row index always points the same |
| direction), per-sample structure β averaged structure. If unstable, |
| averaging blurs structure.""" |
| print(f"\n[{label}] PER-ROW stability across samples:") |
|
|
| |
| |
| n_samples, V, D = all_M.shape |
|
|
| |
| mean_dirs = all_M.mean(axis=0) |
| mean_dir_norms = np.linalg.norm(mean_dirs, axis=1) |
|
|
| |
| |
| |
| |
| print(f" Mean direction norms (concentration of row[i] across samples):") |
| print(f" min: {mean_dir_norms.min():.4f} (most variable row)") |
| print(f" max: {mean_dir_norms.max():.4f} (most stable row)") |
| print(f" mean: {mean_dir_norms.mean():.4f}") |
|
|
| return { |
| 'mean_dir_norms_min': float(mean_dir_norms.min()), |
| 'mean_dir_norms_max': float(mean_dir_norms.max()), |
| 'mean_dir_norms_mean': float(mean_dir_norms.mean()), |
| 'mean_dirs': mean_dirs.tolist(), |
| 'mean_dir_norms': mean_dir_norms.tolist(), |
| } |
|
|
|
|
| |
| |
| |
|
|
| def test_per_sample_clustering(all_M, k_test=5, n_samples_to_check=20): |
| """For each of n_samples_to_check samples, run k-means clustering on its |
| own 32 rows. If we consistently get strong clusters at the same k, the |
| structure is intrinsic to each sample. If silhouette varies wildly, the |
| averaged result was an artifact.""" |
| print(f"\nPER-SAMPLE k=5 clustering (testing first {n_samples_to_check} samples):") |
|
|
| silhouettes = [] |
| for i in range(min(n_samples_to_check, all_M.shape[0])): |
| M = all_M[i] |
| try: |
| km = KMeans(n_clusters=k_test, n_init=10, random_state=42) |
| labels = km.fit_predict(M) |
| if len(set(labels)) >= 2: |
| sil = silhouette_score(M, labels) |
| silhouettes.append(sil) |
| except Exception: |
| pass |
|
|
| silhouettes = np.array(silhouettes) |
| print(f" Silhouette across samples (k={k_test}):") |
| print(f" mean: {silhouettes.mean():.3f}") |
| print(f" std: {silhouettes.std():.3f}") |
| print(f" range: [{silhouettes.min():.3f}, {silhouettes.max():.3f}]") |
|
|
| return { |
| 'k_tested': k_test, |
| 'silhouettes_per_sample': silhouettes.tolist(), |
| 'mean_silhouette': float(silhouettes.mean()), |
| 'std_silhouette': float(silhouettes.std()), |
| 'min_silhouette': float(silhouettes.min()) if len(silhouettes) > 0 else None, |
| 'max_silhouette': float(silhouettes.max()) if len(silhouettes) > 0 else None, |
| } |
|
|
|
|
| |
| |
| |
|
|
| def test_angular_distribution(all_M, label): |
| """Project all per-sample row vectors to unit sphere (re-normalize), |
| then look at distribution of pairwise angles. Uniform distribution gives |
| a specific angular density. Clustered gives bimodal angles. Polar / band |
| structures give characteristic patterns.""" |
| print(f"\n[{label}] ANGULAR DISTRIBUTION:") |
|
|
| |
| all_rows = all_M.reshape(-1, all_M.shape[-1]) |
| norms = np.linalg.norm(all_rows, axis=1, keepdims=True) |
| unit_rows = all_rows / np.clip(norms, 1e-12, None) |
|
|
| |
| n_subset = min(500, unit_rows.shape[0]) |
| idx = np.random.RandomState(42).choice(unit_rows.shape[0], n_subset, replace=False) |
| subset = unit_rows[idx] |
|
|
| |
| cosines = subset @ subset.T |
| triu_idx = np.triu_indices(n_subset, k=1) |
| pairwise_cos = cosines[triu_idx] |
| pairwise_angles = np.arccos(np.clip(pairwise_cos, -1, 1)) |
|
|
| |
| |
| |
|
|
| mean_angle = float(pairwise_angles.mean()) |
| median_angle = float(np.median(pairwise_angles)) |
| expected_uniform_mean = math.pi / 2 |
|
|
| print(f" Pairwise angle stats (radians):") |
| print(f" mean: {mean_angle:.3f} (uniform β Ο/2 = 1.571)") |
| print(f" median: {median_angle:.3f}") |
| print(f" deviation from uniform mean: {abs(mean_angle - expected_uniform_mean):.3f}") |
|
|
| |
| |
| |
|
|
| near_zero = (pairwise_angles < 0.5).sum() / len(pairwise_angles) |
| near_pi = (pairwise_angles > math.pi - 0.5).sum() / len(pairwise_angles) |
| near_perp = ((pairwise_angles > math.pi / 2 - 0.3) & |
| (pairwise_angles < math.pi / 2 + 0.3)).sum() / len(pairwise_angles) |
|
|
| print(f" fraction near 0 (parallel): {near_zero:.3f}") |
| print(f" fraction near Ο (antiparallel): {near_pi:.3f}") |
| print(f" fraction near Ο/2 (perpendicular): {near_perp:.3f}") |
|
|
| return { |
| 'mean_angle': mean_angle, |
| 'median_angle': median_angle, |
| 'expected_uniform_mean': expected_uniform_mean, |
| 'fraction_near_zero': float(near_zero), |
| 'fraction_near_pi': float(near_pi), |
| 'fraction_near_perp': float(near_perp), |
| 'pairwise_angles_subset': pairwise_angles[:200].tolist(), |
| } |
|
|
|
|
| |
| |
| |
|
|
| def test_antipodal(all_M, label): |
| """Check if each row has a near-antipodal partner. If 32 rows form |
| 16 antipodal pairs, that's a different geometric structure than |
| 32 independent points.""" |
| print(f"\n[{label}] ANTIPODAL STRUCTURE:") |
|
|
| mean_dirs = all_M.mean(axis=0) |
| norms = np.linalg.norm(mean_dirs, axis=1, keepdims=True) |
| unit_dirs = mean_dirs / np.clip(norms, 1e-12, None) |
|
|
| |
| cosines = unit_dirs @ unit_dirs.T |
| np.fill_diagonal(cosines, 1.0) |
| most_anti_cos = cosines.min(axis=1) |
|
|
| |
| n_antipodal_pairs = (most_anti_cos < -0.9).sum() // 2 |
|
|
| print(f" Most-antipodal cos for each row:") |
| print(f" min: {most_anti_cos.min():.4f}") |
| print(f" mean: {most_anti_cos.mean():.4f}") |
| print(f" fraction with antipode (cos < -0.9): " |
| f"{(most_anti_cos < -0.9).mean():.3f}") |
| print(f" Estimated antipodal pairs: {n_antipodal_pairs} / " |
| f"{all_M.shape[1]//2} possible") |
|
|
| return { |
| 'most_antipodal_cosines_min': float(most_anti_cos.min()), |
| 'most_antipodal_cosines_mean': float(most_anti_cos.mean()), |
| 'fraction_with_antipode': float((most_anti_cos < -0.9).mean()), |
| 'estimated_antipodal_pairs': int(n_antipodal_pairs), |
| } |
|
|
|
|
| |
| |
| |
|
|
def _effective_rank(M):
    """Return exp(Shannon entropy) of M's singular values normalised to sum to 1.

    This is roughly 1 for a single dominant direction and min(M.shape) when all
    singular values are equal. The 1e-12 inside the log keeps zero singular
    values from producing log(0). The result is NaN for an all-zero matrix.
    """
    sv = np.linalg.svd(M, compute_uv=False)
    p = sv / sv.sum()
    return math.exp(-(p * np.log(p + 1e-12)).sum())


def comparison_test(all_M_p, all_M_h2):
    """Side-by-side: P-Class vs H2a. What's the actual structural difference?

    Compares the effective rank of each model's sample-averaged M. Dimension
    utilisation is effective rank divided by the row dimension D (last axis).

    Args:
        all_M_p: P-Class per-sample rows, indexed (sample, row, dim).
        all_M_h2: H2a per-sample rows, indexed (sample, row, dim).

    Returns:
        dict with 'effective_rank_p', 'effective_rank_h2', 'p_dim_utilization'
        and 'h2_dim_utilization'.
    """
    d_p = all_M_p.shape[-1]
    d_h2 = all_M_h2.shape[-1]

    rule = "─" * 70
    print("\n" + rule)
    print(f"DIRECT COMPARISON: P-Class (D={d_p}) vs H2a (D={d_h2})")
    print(rule)

    # Averaging across samples is deliberate here: this measures the shared geometry.
    erank_p = _effective_rank(all_M_p.mean(axis=0))
    erank_h2 = _effective_rank(all_M_h2.mean(axis=0))

    print("\n Effective rank of M_avg:")
    print(f" P-Class (D={d_p}): {erank_p:.2f} of {d_p} possible")
    print(f" H2a (D={d_h2}): {erank_h2:.2f} of {d_h2} possible")
    print(f" P uses {erank_p / d_p * 100:.0f}% of available dims")
    print(f" H2 uses {erank_h2 / d_h2 * 100:.0f}% of available dims")

    return {
        'effective_rank_p': float(erank_p),
        'effective_rank_h2': float(erank_h2),
        'p_dim_utilization': float(erank_p / d_p),
        'h2_dim_utilization': float(erank_h2 / d_h2),
    }
|
|
|
|
| |
| |
| |
|
|
def plot_diagnostic(all_M_p, all_M_h2, results, output_path):
    """Draw the 2x3 P-Class vs H2a diagnostic figure, save it, show it, then close it.

    Panels:
      1. histogram of per-sample row norms for both models;
      2. rows of sample 0 for P-Class (first three coordinates) over a unit-sphere
         wireframe;
      3. the same for H2a (first three of its coordinates);
      4. box plot of per-sample k-means silhouettes;
      5. histogram of the sampled pairwise angles;
      6. per-row stability (mean direction norm), sorted descending.

    Args:
        all_M_p, all_M_h2: per-sample rows indexed (sample, row, dim). The 3-D
            panels need at least three coordinates per row.
        results: dict built by main(). Reads 'per_sample_clustering_p/_h2',
            'angular_p/_h2' and 'stability_p/_h2'.
        output_path: file the figure is written to (120 dpi).
    """
    fig = plt.figure(figsize=(18, 12))

    # Panel 1: per-sample row norms (sphere-norm check).
    ax1 = fig.add_subplot(2, 3, 1)
    p_norms = np.linalg.norm(all_M_p, axis=2).ravel()
    h2_norms = np.linalg.norm(all_M_h2, axis=2).ravel()
    ax1.hist(p_norms, bins=50, alpha=0.5, label='P-Class', color='red')
    ax1.hist(h2_norms, bins=50, alpha=0.5, label='H2a', color='blue')
    ax1.axvline(1.0, color='black', linestyle='--', alpha=0.7,
                label='unit sphere')
    ax1.set_xlabel('Row norm')
    ax1.set_ylabel('Count')
    ax1.set_title('Per-sample row norms\n'
                  '(both should be ~1.0 if sphere-normed)')
    ax1.legend()

    # Unit-sphere wireframe shared by both 3-D panels.
    u = np.linspace(0, 2 * np.pi, 20)
    v = np.linspace(0, np.pi, 20)
    x_s = np.outer(np.cos(u), np.sin(v))
    y_s = np.outer(np.sin(u), np.sin(v))
    z_s = np.outer(np.ones_like(u), np.cos(v))

    # Panel 2: one P-Class sample; colour encodes row index (sized from the data).
    ax2 = fig.add_subplot(2, 3, 2, projection='3d')
    sample_p = all_M_p[0]
    ax2.scatter(sample_p[:, 0], sample_p[:, 1], sample_p[:, 2],
                c=np.arange(sample_p.shape[0]), cmap='viridis', s=80,
                edgecolors='black', linewidths=0.5)
    ax2.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')
    ax2.set_title(f'P-Class (D={all_M_p.shape[-1]}) — single sample\n'
                  f'{sample_p.shape[0]} rows in 3D')

    # Panel 3: one H2a sample, first three coordinates only.
    ax3 = fig.add_subplot(2, 3, 3, projection='3d')
    sample_h2 = all_M_h2[0]
    ax3.scatter(sample_h2[:, 0], sample_h2[:, 1], sample_h2[:, 2],
                c=np.arange(sample_h2.shape[0]), cmap='viridis', s=80,
                edgecolors='black', linewidths=0.5)
    ax3.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')
    ax3.set_title(f'H2a (D={all_M_h2.shape[-1]}) — single sample\n'
                  f'{sample_h2.shape[0]} rows projected to first 3 dims')

    # Panel 4: per-sample silhouettes. Tick labels are set separately because
    # boxplot's `labels` keyword was renamed to `tick_labels` in Matplotlib 3.9.
    ax4 = fig.add_subplot(2, 3, 4)
    sils_p = results['per_sample_clustering_p']['silhouettes_per_sample']
    sils_h2 = results['per_sample_clustering_h2']['silhouettes_per_sample']
    k_used = results['per_sample_clustering_p']['k_tested']
    ax4.boxplot([sils_p, sils_h2])
    ax4.set_xticks([1, 2])
    ax4.set_xticklabels(['P-Class', 'H2a'])
    ax4.axhline(0.5, color='red', linestyle='--', alpha=0.5,
                label='strong cluster threshold')
    ax4.set_ylabel(f'Silhouette score (k={k_used} per-sample)')
    ax4.set_title('Per-sample cluster stability\n'
                  '(consistent silhouette = real cluster structure)')
    ax4.legend(fontsize=8)
    ax4.grid(alpha=0.3)

    # Panel 5: pairwise angle distribution.
    ax5 = fig.add_subplot(2, 3, 5)
    angles_p = results['angular_p']['pairwise_angles_subset']
    angles_h2 = results['angular_h2']['pairwise_angles_subset']
    ax5.hist(angles_p, bins=40, alpha=0.5, label='P-Class', color='red',
             density=True)
    ax5.hist(angles_h2, bins=40, alpha=0.5, label='H2a', color='blue',
             density=True)
    ax5.axvline(math.pi / 2, color='black', linestyle='--', alpha=0.7,
                label='π/2 (uniform peak)')
    ax5.set_xlabel('Pairwise angle (radians)')
    ax5.set_ylabel('Density')
    ax5.set_title('Pairwise angle distribution\n'
                  '(uniform sphere peaks at π/2)')
    ax5.legend(fontsize=8)

    # Panel 6: per-row stability, most stable first.
    ax6 = fig.add_subplot(2, 3, 6)
    stab_p = results['stability_p']['mean_dir_norms']
    stab_h2 = results['stability_h2']['mean_dir_norms']
    ax6.plot(sorted(stab_p, reverse=True), 'o-', label='P-Class',
             color='red', markersize=5)
    ax6.plot(sorted(stab_h2, reverse=True), 's-', label='H2a',
             color='blue', markersize=5)
    ax6.set_xlabel('Row index (sorted by stability)')
    ax6.set_ylabel('Mean direction norm\n(1.0 = perfectly stable)')
    ax6.set_title(f'Per-row stability across {all_M_p.shape[0]} samples\n'
                  '(low = row direction depends on input)')
    ax6.legend()
    ax6.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure once displayed/saved
|
|
|
|
| |
| |
| |
|
|
def _print_banner(title):
    """Print ``title`` framed above and below by a 70-character horizontal rule."""
    rule = "─" * 70
    print("\n" + rule)
    print(title)
    print(rule)


def main():
    """Run the full v2 probe end to end.

    Loads the P-rank09 (D=3 candidate) and Q-rank02 H2a (D=4 reference)
    checkpoints and collects per-sample M rows from both. It then runs every
    diagnostic, prints an interpretation, and writes OUTPUT_JSON and OUTPUT_PLOT.

    Returns:
        dict of every diagnostic's results, keyed by test and model.
    """
    print("Loading P-rank09 (D=3 candidate)...")
    p_model, p_cfg = load_model('rank09', RANK09_CKPT)
    print(f" V={p_cfg.matrix_v}, D={p_cfg.D}, params="
          f"{sum(p.numel() for p in p_model.parameters()):,}")

    print("\nLoading Q-rank02 H2a (D=4 reference)...")
    h2_model, h2_cfg = load_model('rank02', RANK02_CKPT)
    print(f" V={h2_cfg.matrix_v}, D={h2_cfg.D}, params="
          f"{sum(p.numel() for p in h2_model.parameters()):,}")

    print("\nCollecting M rows from gaussian inputs (P-Class)...")
    all_M_p = collect_per_sample_M(p_model, p_cfg)
    print(f" shape: {all_M_p.shape}")

    print("Collecting M rows from gaussian inputs (H2a)...")
    all_M_h2 = collect_per_sample_M(h2_model, h2_cfg)
    print(f" shape: {all_M_h2.shape}")

    _print_banner("SPHERE-NORM VERIFICATION")
    norms_p = test_sphere_norm(all_M_p, "P-Class (D=3)")
    norms_h2 = test_sphere_norm(all_M_h2, "H2a (D=4)")

    _print_banner("ROW STABILITY ACROSS SAMPLES")
    stab_p = test_row_stability(all_M_p, "P-Class (D=3)")
    stab_h2 = test_row_stability(all_M_h2, "H2a (D=4)")

    _print_banner("PER-SAMPLE CLUSTERING")
    cluster_p = test_per_sample_clustering(all_M_p, k_test=5)
    cluster_h2 = test_per_sample_clustering(all_M_h2, k_test=5)

    _print_banner("ANGULAR DISTRIBUTION")
    angular_p = test_angular_distribution(all_M_p, "P-Class (D=3)")
    angular_h2 = test_angular_distribution(all_M_h2, "H2a (D=4)")

    _print_banner("ANTIPODAL STRUCTURE")
    antipodal_p = test_antipodal(all_M_p, "P-Class (D=3)")
    antipodal_h2 = test_antipodal(all_M_h2, "H2a (D=4)")

    comparison = comparison_test(all_M_p, all_M_h2)

    all_results = {
        'sphere_norm_p': norms_p,
        'sphere_norm_h2': norms_h2,
        'stability_p': stab_p,
        'stability_h2': stab_h2,
        'per_sample_clustering_p': cluster_p,
        'per_sample_clustering_h2': cluster_h2,
        'angular_p': angular_p,
        'angular_h2': angular_h2,
        'antipodal_p': antipodal_p,
        'antipodal_h2': antipodal_h2,
        'comparison': comparison,
    }

    _print_banner("INTERPRETATION")

    p_normed = norms_p['sphere_normed_per_sample']
    h2_normed = norms_h2['sphere_normed_per_sample']

    print("\nSphere-norm per-sample:")
    print(f" P-Class: {'YES' if p_normed else 'NO'} "
          f"(mean norm {norms_p['row_norms_mean']:.3f})")
    print(f" H2a: {'YES' if h2_normed else 'NO'} "
          f"(mean norm {norms_h2['row_norms_mean']:.3f})")

    print("\nPer-sample cluster strength (k=5 silhouette):")
    print(f" P-Class: mean {cluster_p['mean_silhouette']:.3f}, "
          f"std {cluster_p['std_silhouette']:.3f}")
    print(f" H2a: mean {cluster_h2['mean_silhouette']:.3f}, "
          f"std {cluster_h2['std_silhouette']:.3f}")

    print("\nRow direction stability (1.0 = perfectly stable):")
    print(f" P-Class: {stab_p['mean_dir_norms_mean']:.3f}")
    print(f" H2a: {stab_h2['mean_dir_norms_mean']:.3f}")

    print("\nAngular distribution mean (uniform = π/2 ≈ 1.571):")
    print(f" P-Class: {angular_p['mean_angle']:.3f}")
    print(f" H2a: {angular_h2['mean_angle']:.3f}")

    print("\nDimension utilization:")
    print(f" P-Class: {comparison['p_dim_utilization']*100:.0f}% of {p_cfg.D}-D")
    print(f" H2a: {comparison['h2_dim_utilization']*100:.0f}% of {h2_cfg.D}-D")

    print("\nKEY QUESTIONS ANSWERED:")

    # Every branch prints a verdict, so this section is never silently empty.
    sil_p = cluster_p['mean_silhouette']
    if not p_normed:
        print(" ? P-Class rows are not unit-norm per-sample; "
              "clustering verdict not drawn")
    elif math.isnan(sil_p):
        print(" ? P-Class: no valid per-sample silhouette scores")
    elif sil_p > 0.5:
        print(" ✓ P-Class IS clustered per-sample (real structure)")
    elif sil_p < 0.3:
        print(" ✗ P-Class clusters were AVERAGING ARTIFACT")
        print(f" Per-sample silhouette only {sil_p:.3f}")
    else:
        print(f" ~ P-Class per-sample clustering is inconclusive "
              f"(silhouette {sil_p:.3f}, between 0.3 and 0.5)")

    if antipodal_p['fraction_with_antipode'] > 0.5:
        print(f" ✓ P-Class has antipodal structure "
              f"({antipodal_p['estimated_antipodal_pairs']} pairs)")
    else:
        print(f" · P-Class shows no dominant antipodal structure "
              f"(fraction with antipode {antipodal_p['fraction_with_antipode']:.3f})")

    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"\nSaved: {OUTPUT_JSON}")

    plot_diagnostic(all_M_p, all_M_h2, all_results, OUTPUT_PLOT)
    print(f"Saved: {OUTPUT_PLOT}")

    return all_results


if __name__ == '__main__':
    # Bound at module level so the results can be inspected interactively.
    results = main()