| |
| """ |
| Large-scale experiment runner for NeurIPS-quality results. |
| Runs synthetic (full grid), model-family ablation, and real-data experiments. |
| Includes sanity checks and bootstrap CIs. |
| """ |
| import os, sys, json, time, yaml, argparse |
| import numpy as np |
| from datetime import datetime |
| from collections import defaultdict |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from src.data import (generate_gamma_poisson_data, generate_gaussian_gaussian_data, |
| generate_gaussian_gamma_data, sample_deletions) |
| from src.model import PoissonGammaVI, GaussianGaussianVI, GaussianGammaMAP |
| from src.graph_utils import build_adjacency, compute_graph_stats |
| from src.metrics import (compute_all_metrics, compute_deletion_influence_by_distance, |
| fit_exponential_decay, compute_local_error) |
| from src.unlearning import one_step_downdate_poisson_gamma |
| from src.utils import FitResult, generate_run_id, generate_config_id, save_jsonl, ensure_dir |
|
|
|
|
| |
| |
| |
|
|
def run_sanity_checks(model, edges, full_params, edge_to_del, exact_params,
                      local_params_by_R, ws_params, model_family, N, M, K):
    """Run the post-fit sanity checks for one deletion and return them as a dict.

    Args:
        model: Fitted model object. If it exposes ``compute_elbo`` the
            objective on the full data is evaluated and checked for finiteness.
        edges: Sequence of ``(i, j, x)`` observations.
        full_params: Parameters fitted on all edges. For ``'poisson_gamma'``
            this must be a mapping with entries ``'a'``, ``'b'``, ``'c'``, ``'d'``.
        edge_to_del: The deleted edge (accepted for interface symmetry; not
            read by any check here).
        exact_params: Parameters from the exact refit without the edge.
        local_params_by_R: Mapping ``radius -> parameters`` of the local refits.
            May be empty.
        ws_params: Parameters of the warm-started global refit, or ``None``.
        model_family: Model family name; Poisson-Gamma specific checks only
            run for ``'poisson_gamma'``.
        N, M, K: Problem dimensions (accepted for interface symmetry; unused).

    Returns:
        Dict of check name -> result. Which keys are present depends on the
        model family and on which inputs were supplied.
    """
    checks = {}

    # --- Objective on the full fit is finite --------------------------------
    if hasattr(model, 'compute_elbo'):
        try:
            elbo_full = model.compute_elbo(edges, full_params)
            checks['full_objective'] = elbo_full
            checks['objective_finite'] = bool(np.isfinite(elbo_full))
        except Exception as exc:
            # Best-effort check: a failing objective is recorded, not raised.
            # ``Exception`` (rather than a bare ``except``) lets Ctrl-C and
            # SystemExit propagate.
            checks['objective_finite'] = False
            checks['objective_error'] = repr(exc)

    if model_family == 'poisson_gamma':
        a, b, c, d = (full_params[name] for name in ('a', 'b', 'c', 'd'))
        variational = (a, b, c, d)

        # --- Variational parameters are positive and NaN-free ----------------
        checks['params_positive'] = bool(all(np.all(p > 0) for p in variational))
        # Cover all four blocks, matching the positivity check above.
        checks['params_no_nan'] = bool(
            not any(np.any(np.isnan(p)) for p in variational))

        # --- Responsibilities normalise on a sample of edges -----------------
        from scipy.special import digamma

        resp_ok = True
        for edge in edges[:20]:
            i, j, x = edge
            if x > 0:
                # assumes a[i], b[i], c[j], d[j] are length-K vectors -- TODO confirm
                log_r = digamma(a[i]) - np.log(b[i]) + digamma(c[j]) - np.log(d[j])
                log_r -= log_r.max()  # stabilise the softmax
                r = np.exp(log_r)
                r /= r.sum()
                total = r.sum()
                # NaN compares False against any threshold, so test finiteness
                # explicitly; otherwise broken parameters would pass silently.
                if not np.isfinite(total) or abs(total - 1.0) > 1e-6:
                    resp_ok = False
                    break
        checks['responsibilities_sum_to_one'] = resp_ok

    # --- Exact refit actually moved away from the full fit -------------------
    from src.metrics import compute_all_param_vector
    v_full = compute_all_param_vector(full_params, model_family)
    v_exact = compute_all_param_vector(exact_params, model_family)
    diff = np.linalg.norm(v_full - v_exact)
    checks['exact_differs_from_full'] = bool(diff > 1e-10)
    checks['exact_full_diff_norm'] = float(diff)

    # --- Local error per radius, and whether it shrinks with R ---------------
    errors_by_R = {}
    for R, lp in sorted(local_params_by_R.items()):
        err = compute_local_error(lp, exact_params, model_family)
        errors_by_R[R] = err['relative_error']
    checks['errors_by_R'] = errors_by_R
    if len(errors_by_R) >= 2:
        R_list = sorted(errors_by_R)
        checks['error_decreases_with_R'] = bool(
            errors_by_R[R_list[-1]] <= errors_by_R[R_list[0]])

    # --- Warm start is comparable to the R=4 local fit -----------------------
    # Only meaningful when an R=4 fit exists; looking it up directly also
    # avoids calling max() on an empty mapping.
    r4_err = errors_by_R.get(4)
    if ws_params is not None and r4_err is not None:
        ws_err = compute_local_error(ws_params, exact_params, model_family)
        checks['ws_error'] = ws_err['relative_error']
        checks['ws_close_to_R4'] = bool(ws_err['relative_error'] <= r4_err * 5 + 0.01)

    return checks
|
|
|
|
| |
| |
| |
|
|
def build_full_synthetic_configs():
    """Enumerate the full Poisson-Gamma synthetic sweep.

    The grid is 3 latent sizes (K) x 3 graph types x 3 average degrees x
    3 count scales x 2 prior strengths = 162 configurations, emitted with K
    varying slowest and prior strength varying fastest.

    Returns:
        List of config dicts, one per grid point.
    """
    side = 300
    shared_radii = [1, 2, 3, 4]  # one list object referenced by every config
    deletions = 50
    hyper_names = ('a0', 'b0', 'c0', 'd0')

    latent_sizes = (5, 10, 20)
    graph_kinds = ('bounded_degree', 'erdos_renyi', 'power_law')
    degrees = (5, 10, 20)
    scales = (('low', 0.5), ('medium', 1.0), ('high', 3.0))
    # All four Gamma hyper-parameters share one value per strength level.
    prior_levels = (('strong', 1.0), ('weak', 0.1))

    return [
        {
            'N': side, 'M': side, 'K': latent_size,
            'graph_type': graph_kind, 'avg_degree': degree,
            'count_scale': scale, 'count_scale_label': scale_label,
            'prior_strength': strength,
            # Fresh dict per config so prior settings are never aliased.
            'prior_config': dict.fromkeys(hyper_names, hyper_value),
            'num_deletions': deletions, 'radii': shared_radii, 'seed': 42,
            'model_family': 'poisson_gamma', 'max_iter': 300, 'tol': 1e-5,
        }
        for latent_size in latent_sizes
        for graph_kind in graph_kinds
        for degree in degrees
        for scale_label, scale in scales
        for strength, hyper_value in prior_levels
    ]
|
|
|
|
def build_model_family_configs():
    """Build the model-family ablation grid, balanced across three families.

    For every (K, graph type, average degree) combination this emits six
    Poisson-Gamma configs, then six Gaussian-Gaussian configs, then six
    Gaussian-Gamma MAP configs, in that order.

    Returns:
        List of config dicts (2 K x 3 graphs x 2 degrees x 18 = 216 entries).
    """
    shared_radii = [1, 2, 3, 4]  # one list object referenced by every config
    hyper_names = ('a0', 'b0', 'c0', 'd0')

    count_scales = (('low', 0.5), ('medium', 1.0), ('high', 3.0))
    noise_levels = (('high_noise', 2.0), ('medium_noise', 1.0), ('low_noise', 0.3))
    poisson_priors = (('strong', 1.0), ('weak', 0.3))
    gaussian_priors = (('strong_prior', 0.5), ('weak_prior', 3.0))
    gamma_map_priors = (('strong', 2.0), ('weak', 0.3))

    grid = []
    for latent_size in (5, 10):
        for graph_kind in ('bounded_degree', 'erdos_renyi', 'power_law'):
            for degree in (5, 15):
                # Key fragments common to all three families; splatting them
                # around the family-specific keys keeps the key order intact.
                head = {'N': 200, 'M': 200, 'K': latent_size,
                        'graph_type': graph_kind, 'avg_degree': degree}
                tail = {'num_deletions': 30, 'radii': shared_radii, 'seed': 42}

                # Poisson-Gamma: count scale x prior strength.
                grid.extend(
                    {**head,
                     'count_scale': scale, 'count_scale_label': scale_label,
                     'prior_strength': strength,
                     'prior_config': dict.fromkeys(hyper_names, hyper_value),
                     **tail,
                     'model_family': 'poisson_gamma', 'max_iter': 300, 'tol': 1e-5}
                    for scale_label, scale in count_scales
                    for strength, hyper_value in poisson_priors)

                # Gaussian-Gaussian: observation noise x prior std.
                grid.extend(
                    {**head,
                     'sigma_x': noise, 'sigma_x_label': noise_label,
                     'sigma_U': prior_std, 'sigma_V': prior_std,
                     'prior_strength': strength,
                     **tail,
                     'model_family': 'gaussian_gaussian', 'max_iter': 300, 'tol': 1e-5}
                    for noise_label, noise in noise_levels
                    for strength, prior_std in gaussian_priors)

                # Gaussian-Gamma MAP: observation noise x Gamma prior strength.
                grid.extend(
                    {**head,
                     'sigma_x': noise, 'sigma_x_label': noise_label,
                     'prior_strength': strength,
                     'prior_config': dict.fromkeys(hyper_names, hyper_value),
                     **tail,
                     'model_family': 'gaussian_gamma_map',
                     'lr': 0.05, 'max_iter': 2000, 'tol': 1e-6, 'grad_clip': 10.0}
                    for noise_label, noise in noise_levels
                    for strength, hyper_value in gamma_map_priors)

    return grid
|
|
|
|
| |
| |
| |
|
|
def run_config(config):
    """Run one synthetic configuration end-to-end, including sanity checks.

    Generates a synthetic dataset for ``config['model_family']``, fits the
    model on all edges, and then, for every sampled edge deletion, runs an
    exact refit, one local refit per radius, a warm-started global refit and
    (Poisson-Gamma only) a one-step downdate. Metrics and runtimes are
    collected into one record per deletion; sanity checks are run for the
    first three deletions.

    Args:
        config: Dict describing the run.
            Required keys: ``model_family`` (one of ``'poisson_gamma'``,
            ``'gaussian_gaussian'``, ``'gaussian_gamma_map'``),
            ``graph_type``, ``N``, ``M``, ``K``, ``avg_degree``.
            Optional keys (default): ``radii`` ([1, 2, 3, 4]),
            ``num_deletions`` (50), ``seed`` (42), ``max_iter`` (300),
            ``tol`` (1e-5), ``prior_config`` (``a0``/``c0`` 0.3,
            ``b0``/``d0`` 1.0), ``count_scale`` (1.0), ``prior_strength``
            ('strong'), ``sigma_x``/``sigma_U``/``sigma_V`` (1.0),
            ``lr`` (0.05), ``grad_clip`` (10.0).

    Returns:
        ``(records, sanity_results)``: a list of per-deletion record dicts and
        a list of sanity-check dicts. Both lists are empty when the generated
        dataset has fewer than 10 edges and the configuration is skipped.

    Raises:
        ValueError: If ``model_family`` is not a supported family.
        KeyError: If a required config key is missing.
    """
    model_family = config['model_family']
    # Fail early with a clear message instead of hitting an unbound ``edges``
    # / ``model`` further down when no branch matches.
    if model_family not in ('poisson_gamma', 'gaussian_gaussian', 'gaussian_gamma_map'):
        raise ValueError(f"Unknown model_family: {model_family!r}")

    gt = config['graph_type']
    N, M, K = config['N'], config['M'], config['K']
    avg_degree = config['avg_degree']
    radii = config.get('radii', [1, 2, 3, 4])
    num_del = config.get('num_deletions', 50)
    seed = config.get('seed', 42)
    max_iter = config.get('max_iter', 300)
    tol = config.get('tol', 1e-5)

    config_id = generate_config_id(config)
    run_id = generate_run_id()

    prior_cfg = config.get('prior_config', {})
    a0 = prior_cfg.get('a0', 0.3)
    b0 = prior_cfg.get('b0', 1.0)
    c0 = prior_cfg.get('c0', 0.3)
    d0 = prior_cfg.get('d0', 1.0)
    count_scale = config.get('count_scale', 1.0)
    prior_strength = config.get('prior_strength', 'strong')
    sigma_x = config.get('sigma_x', 1.0)
    sigma_U = config.get('sigma_U', 1.0)
    sigma_V = config.get('sigma_V', 1.0)

    # --- Generate data (only the edge list is used downstream) ---------------
    if model_family == 'poisson_gamma':
        edges, _, _, _ = generate_gamma_poisson_data(
            N, M, K, gt, avg_degree, count_scale, a0, b0, c0, d0, seed=seed)
    elif model_family == 'gaussian_gaussian':
        edges, _, _, _ = generate_gaussian_gaussian_data(
            N, M, K, gt, avg_degree, sigma_U, sigma_V, sigma_x, seed=seed)
    else:  # 'gaussian_gamma_map' (validated above)
        edges, _, _, _ = generate_gaussian_gamma_data(
            N, M, K, gt, avg_degree, a0, b0, c0, d0, sigma_x, seed=seed)

    if len(edges) < 10:
        print(f" SKIP: only {len(edges)} edges")
        # Same (records, sanity_results) shape as the normal return so callers
        # can always unpack two values.
        return [], []

    # --- Build the model ------------------------------------------------------
    if model_family == 'poisson_gamma':
        model = PoissonGammaVI(N, M, K, a0, b0, c0, d0,
                               max_iter=max_iter, tol=tol, seed=seed)
    elif model_family == 'gaussian_gaussian':
        model = GaussianGaussianVI(N, M, K, sigma_U=sigma_U, sigma_V=sigma_V,
                                   sigma_x=sigma_x,
                                   max_iter=max_iter, tol=tol, seed=seed)
    else:  # 'gaussian_gamma_map'
        model = GaussianGammaMAP(N, M, K, a0, b0, c0, d0,
                                 sigma_x=sigma_x,
                                 lr=config.get('lr', 0.05),
                                 max_iter=max_iter, tol=tol, seed=seed,
                                 grad_clip=config.get('grad_clip', 10.0))

    # --- Full fit -------------------------------------------------------------
    t0 = time.time()
    full_result = model.fit_full(edges)
    t_full = time.time() - t0
    full_params = full_result.params

    # --- Sample the edges to delete --------------------------------------------
    u2i, i2u, _ = build_adjacency(edges, N, M)
    dels = sample_deletions(edges, u2i, i2u, num_del, seed=seed)

    records = []
    sanity_results = []

    for del_idx, (edge_to_del, del_type) in enumerate(dels):
        i_del, j_del, x_del = edge_to_del

        # Exact retraining without the edge, initialised at the full fit.
        exact_result = model.fit_without_edge(edges, edge_to_del, init_params=full_params)

        # Local refits, one per radius.
        local_results = {}
        local_params = {}
        for R in radii:
            local_fit = model.fit_local(edges, edge_to_del, R, init_params=full_params)
            local_results[R] = local_fit
            local_params[R] = local_fit.params

        # Warm-started global refit.
        ws_result = model.fit_warm_start_global(edges, edge_to_del, init_params=full_params)

        # One-step downdate is only available for the Poisson-Gamma family.
        one_step_params = None
        one_step_runtime = None
        if model_family == 'poisson_gamma':
            os_res = one_step_downdate_poisson_gamma(
                edges, edge_to_del, full_params, N, M, K, a0, b0, c0, d0)
            one_step_params = os_res.params
            one_step_runtime = os_res.runtime_sec

        # Family-specific extras forwarded to the metric computation.
        if model_family == 'poisson_gamma':
            model_kwargs = {'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0}
        else:
            model_kwargs = {'sigma_x': sigma_x}

        metrics = compute_all_metrics(
            full_params, exact_result.params, local_params,
            ws_result.params, one_step_params,
            edge_to_del, edges, N, M, K,
            model_family, radii=radii, model_kwargs=model_kwargs)

        # Sanity checks are only run for the first three deletions.
        if del_idx < 3:
            sanity = run_sanity_checks(
                model, edges, full_params, edge_to_del,
                exact_result.params, local_params, ws_result.params,
                model_family, N, M, K)
            sanity_results.append(sanity)

        record = {
            'run_id': run_id, 'config_id': config_id,
            'dataset_type': 'synthetic', 'dataset_name': f'synthetic_{model_family}',
            'model_family': model_family,
            'inference_type': 'vi' if model_family != 'gaussian_gamma_map' else 'map',
            'likelihood': 'poisson' if model_family == 'poisson_gamma' else 'gaussian',
            'prior': 'gamma' if 'gamma' in model_family else 'gaussian',
            'graph_type': gt, 'seed': seed, 'N': N, 'M': M, 'K': K,
            'avg_degree': avg_degree,
            'count_scale': count_scale if model_family == 'poisson_gamma' else None,
            'prior_strength': prior_strength,
            'deletion_edge': [int(i_del), int(j_del), float(x_del)],
            'deletion_type': del_type, 'deletion_index': del_idx,
            'runtime_full': t_full, 'runtime_exact': exact_result.runtime_sec,
            'runtime_warm_start': ws_result.runtime_sec,
            'runtime_one_step': one_step_runtime,
            'exact_converged': exact_result.converged,
            'exact_iterations': exact_result.n_iterations,
            'full_converged': full_result.converged,
        }

        for R in radii:
            record[f'runtime_local_R{R}'] = local_results[R].runtime_sec
            record[f'local_R{R}_converged'] = local_results[R].converged
            record[f'local_R{R}_iterations'] = local_results[R].n_iterations

        record.update(metrics)
        # Flatten the per-distance influence mapping into top-level columns.
        if 'influence_by_distance' in record:
            for d_str, val in record['influence_by_distance'].items():
                record[f'influence_d{d_str}'] = val

        record['regime'] = f"{gt}_{prior_strength}_deg{avg_degree}"
        if model_family == 'poisson_gamma':
            record['regime'] += f"_cs{count_scale}"
            record['a0'] = a0
            record['b0'] = b0
            record['c0'] = c0
            record['d0'] = d0
        if model_family in ('gaussian_gaussian', 'gaussian_gamma_map'):
            record['sigma_x'] = sigma_x
        if model_family == 'gaussian_gaussian':
            record['sigma_U'] = sigma_U
            record['sigma_V'] = sigma_V

        records.append(record)

    return records, sanity_results
|
|
|
|
def main():
    """Command-line entry point: run the selected config grid and save results.

    ``--mode`` selects the grid (``full_synthetic``, ``model_family`` or
    ``both``); ``--max_configs`` optionally truncates it to the first N
    configs. Records are written to ``results/raw/<label>_<timestamp>.jsonl``
    after each config, and all sanity-check results are written to
    ``results/raw/sanity_<label>_<timestamp>.jsonl`` at the end. A failing
    config is reported with its traceback and does not abort the sweep.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='full_synthetic',
                        choices=['full_synthetic', 'model_family', 'both'])
    parser.add_argument('--max_configs', type=int, default=None)
    args = parser.parse_args()

    if args.mode == 'full_synthetic':
        configs = build_full_synthetic_configs()
        label = 'full_synthetic'
    elif args.mode == 'model_family':
        configs = build_model_family_configs()
        label = 'model_family_v2'
    else:  # 'both' (argparse ``choices`` rules out anything else)
        configs = build_full_synthetic_configs() + build_model_family_configs()
        label = 'all'

    # ``is not None`` so that an explicit ``--max_configs 0`` really means
    # "run nothing" instead of being treated like "no limit".
    if args.max_configs is not None:
        configs = configs[:args.max_configs]

    print(f"Running {len(configs)} configs ({args.mode})")

    output_dir = ensure_dir('results/raw')
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f'{label}_{ts}.jsonl')
    sanity_file = os.path.join(output_dir, f'sanity_{label}_{ts}.jsonl')

    total_records = 0
    all_sanity = []

    for idx, config in enumerate(configs):
        mf = config['model_family']
        gt = config['graph_type']
        K = config['K']
        print(f"\n[{idx+1}/{len(configs)}] {mf} {gt} K={K} deg={config['avg_degree']} ps={config.get('prior_strength','')}")

        try:
            result = run_config(config)
            # A skipped config may come back as a bare empty list rather than
            # a (records, sanity) pair; treat that as "nothing produced"
            # instead of failing on the unpack.
            records, sanity = result if result else ([], [])
            total_records += len(records)
            all_sanity.extend(sanity)
            # NOTE(review): called once per config on the same path, so this
            # presumably relies on save_jsonl appending -- confirm.
            save_jsonl(records, output_file)
            print(f" -> {len(records)} records (total: {total_records})")
        except Exception as e:
            # Top-level boundary: report and continue with the next config.
            print(f" ERROR: {e}")
            import traceback
            traceback.print_exc()

    save_jsonl(all_sanity, sanity_file)

    print(f"\n{'='*60}")
    print(f"Done. {total_records} records in {output_file}")
    print(f"Sanity checks: {len(all_sanity)} in {sanity_file}")
|
|
|
|
# Script entry point: start the sweep only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()
|
|