Oguzz07 committed on
Commit
a47b09f
·
verified ·
1 Parent(s): eabf58d

Add causal_selection/benchmark.py

Browse files
Files changed (1) hide show
  1. causal_selection/benchmark.py +249 -0
causal_selection/benchmark.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main benchmark runner: orchestrates data generation, algorithm runs, feature extraction,
3
+ and result collection into a meta-dataset.
4
+ """
5
+ import os
6
+ import json
7
+ import time
8
+ import numpy as np
9
+ import pandas as pd
10
+ import logging
11
+ import warnings
12
+ from datetime import datetime
13
+
14
+ from causal_selection.data.generator import (
15
+ load_bn_model, get_true_dag_adjmat, dag_to_cpdag, sample_dataset,
16
+ SMALL_NETWORKS, MEDIUM_NETWORKS, LARGE_NETWORKS, ALL_NETWORKS,
17
+ SAMPLE_SIZES, SEEDS_PER_CONFIG, get_network_tier
18
+ )
19
+ from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
20
+ from causal_selection.discovery.evaluator import evaluate_algorithm_result
21
+ from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES
22
+
23
# Import-time setup: every warning is silenced and root logging is configured
# at INFO level with a timestamped format.
warnings.filterwarnings('ignore')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
logger = logging.getLogger(__name__)

# Directory that receives every CSV/JSON artifact written by this module.
RESULTS_DIR = '/app/causal_selection/data/results'
# Algorithm identifiers, in the iteration order of ALGORITHM_POOL.
ALGO_NAMES = list(ALGORITHM_POOL)

# Time budget (seconds) per algorithm per dataset, keyed by network tier.
TIMEOUT_MAP = dict(
    small=60,    # 1 min for small networks
    medium=180,  # 3 min for medium networks
    large=300,   # 5 min for large networks
)
36
+
37
+
38
def run_single_config(network, n_samples, seed, timeout_sec=300):
    """Run all algorithms on a single (network, n_samples, seed) configuration.

    Args:
        network: Network identifier handed to ``load_bn_model``.
        n_samples: Sample count handed to ``sample_dataset``.
        seed: Seed handed to ``sample_dataset``.
        timeout_sec: Per-algorithm time budget forwarded to ``run_algorithm``.

    Returns:
        dict with:
            - 'meta_features': dict of feature values
            - 'metrics': dict of algo_name -> metrics dict
            - 'config': dict with network, n_samples, seed, n_variables
              and n_true_edges
    """
    logger.info(f"=== {network} N={n_samples} seed={seed} ===")

    # Load network and ground truth
    model = load_bn_model(network)
    true_dag, node_names = get_true_dag_adjmat(model)
    true_cpdag = dag_to_cpdag(true_dag)

    # Sample data. perf_counter is monotonic, so the measured interval cannot
    # be distorted by wall-clock adjustments.
    t0 = time.perf_counter()
    df = sample_dataset(model, n_samples, seed=seed)
    sample_time = time.perf_counter() - t0
    logger.info(f"  Sampled {df.shape} in {sample_time:.1f}s")

    # Extract meta-features
    t0 = time.perf_counter()
    features = extract_all_features(df, n_probe_triplets=100)
    feat_time = time.perf_counter() - t0
    logger.info(f"  Extracted {len(features)} features in {feat_time:.1f}s")

    # Run all algorithms; the runtime that gets logged is the one reported in
    # the metrics dict, so no local timer is needed here.
    algo_metrics = {}
    for algo_name in ALGO_NAMES:
        result = run_algorithm(algo_name, df, timeout_sec=timeout_sec)
        metrics = evaluate_algorithm_result(result, true_cpdag)
        algo_metrics[algo_name] = metrics

        status_str = metrics['status']
        if status_str == 'success':
            logger.info(f"  {algo_name:15s}: SHD={metrics['shd']:3d}  F1={metrics['skeleton_f1']:.3f}  "
                        f"time={metrics['runtime']:.1f}s")
        else:
            logger.info(f"  {algo_name:15s}: {status_str}  time={metrics['runtime']:.1f}s")

    return {
        'meta_features': features,
        'metrics': algo_metrics,
        'config': {
            'network': network,
            'n_samples': n_samples,
            'seed': seed,
            'n_variables': len(node_names),
            # Symmetrise the CPDAG so each adjacent pair is counted twice,
            # then halve to get the number of skeleton edges.
            'n_true_edges': int(((true_cpdag + true_cpdag.T) > 0).sum() // 2),
        }
    }
92
+
93
+
94
def build_meta_dataset(networks=None, save_intermediate=True):
    """Run full benchmark and build meta-dataset.

    Every (network, sample size, seed) combination is run through
    ``run_single_config``. A configuration that raises is logged with its
    traceback and skipped, so it contributes no row to any output.

    Args:
        networks: Network names to benchmark; ``None`` means ``ALL_NETWORKS``.
        save_intermediate: When true, the partial CSVs are rewritten after
            each successful configuration whose running index is a multiple
            of 5.

    Returns:
        Tuple ``(X, Y_shd, Y_nshd, configs_df, full_results)``:
            X: pd.DataFrame of meta-features (columns = FEATURE_NAMES)
            Y_shd: pd.DataFrame of SHD per algorithm (columns = algo names)
            Y_nshd: pd.DataFrame of normalized SHD (columns = algo names)
            configs_df: pd.DataFrame with one row per successful config
            full_results: list of full result dicts
    """
    if networks is None:
        networks = ALL_NETWORKS

    all_features = []
    all_shd = []
    all_nshd = []
    all_configs = []
    full_results = []

    # Planned number of configurations; used only for progress logging.
    total_configs = sum(
        len(SAMPLE_SIZES[get_network_tier(net)]) * SEEDS_PER_CONFIG
        for net in networks
    )

    logger.info(f"Starting benchmark: {len(networks)} networks, ~{total_configs} configs")
    config_idx = 0

    for network in networks:
        tier = get_network_tier(network)
        sample_sizes = SAMPLE_SIZES[tier]
        timeout = TIMEOUT_MAP[tier]

        for n_samples in sample_sizes:
            for seed in range(SEEDS_PER_CONFIG):
                config_idx += 1
                logger.info(f"\n[{config_idx}/{total_configs}] "
                            f"{network} N={n_samples} seed={seed}")

                try:
                    result = run_single_config(network, n_samples, seed,
                                               timeout_sec=timeout)

                    # Extract feature vector; features missing from the
                    # extractor output default to 0.0.
                    feat_row = {name: result['meta_features'].get(name, 0.0)
                                for name in FEATURE_NAMES}
                    all_features.append(feat_row)

                    # Extract SHD vector
                    shd_row = {}
                    nshd_row = {}
                    for algo in ALGO_NAMES:
                        m = result['metrics'][algo]
                        shd_row[algo] = m['shd']
                        nshd_row[algo] = m['normalized_shd']
                    all_shd.append(shd_row)
                    all_nshd.append(nshd_row)

                    # Config info
                    all_configs.append(result['config'])
                    full_results.append(result)

                except Exception as e:
                    # One bad configuration must not abort the whole run, but
                    # keep the traceback so the failure can be diagnosed.
                    logger.exception(f"FAILED config {network} N={n_samples} seed={seed}: {e}")
                    continue

                # Save intermediate results periodically
                if save_intermediate and config_idx % 5 == 0:
                    _save_intermediate(all_features, all_shd, all_nshd, all_configs)

    # Build final DataFrames
    X = pd.DataFrame(all_features, columns=FEATURE_NAMES)
    Y_shd = pd.DataFrame(all_shd, columns=ALGO_NAMES)
    Y_nshd = pd.DataFrame(all_nshd, columns=ALGO_NAMES)
    configs_df = pd.DataFrame(all_configs)

    # Save final results
    os.makedirs(RESULTS_DIR, exist_ok=True)
    X.to_csv(os.path.join(RESULTS_DIR, 'meta_features.csv'), index=False)
    Y_shd.to_csv(os.path.join(RESULTS_DIR, 'shd_matrix.csv'), index=False)
    Y_nshd.to_csv(os.path.join(RESULTS_DIR, 'normalized_shd_matrix.csv'), index=False)
    configs_df.to_csv(os.path.join(RESULTS_DIR, 'configs.csv'), index=False)

    # Save full results as JSON
    _save_full_results(full_results)

    logger.info("\n=== BENCHMARK COMPLETE ===")
    logger.info(f"Total configs: {len(all_features)}")
    logger.info(f"Meta-feature matrix: {X.shape}")
    logger.info(f"SHD matrix: {Y_shd.shape}")
    logger.info(f"Results saved to {RESULTS_DIR}")

    return X, Y_shd, Y_nshd, configs_df, full_results
187
+
188
+
189
def _save_intermediate(features, shds, nshds, configs):
    """Write the rows collected so far to the ``*_partial.csv`` files in RESULTS_DIR."""
    os.makedirs(RESULTS_DIR, exist_ok=True)
    partial_outputs = (
        (features, 'meta_features_partial.csv'),
        (shds, 'shd_matrix_partial.csv'),
        (nshds, 'normalized_shd_partial.csv'),
        (configs, 'configs_partial.csv'),
    )
    for rows, filename in partial_outputs:
        target = os.path.join(RESULTS_DIR, filename)
        pd.DataFrame(rows).to_csv(target, index=False)
196
+
197
+
198
def _save_full_results(results):
    """Write the full result dicts to ``full_results.json`` inside RESULTS_DIR.

    RESULTS_DIR is created if it does not exist yet. Values in each result's
    'meta_features' dict and in each per-algorithm metrics dict are converted
    to JSON-friendly types: numpy floats and integers become ``float``, numpy
    booleans become ``bool`` and numpy arrays become nested lists. All other
    values, and the 'config' dict, are passed to ``json.dump`` unchanged.

    Args:
        results: Iterable of dicts with 'config', 'meta_features' and
            'metrics' keys, where 'metrics' maps algorithm name -> dict.
    """
    def _to_jsonable(value):
        # Numeric numpy scalars are written as floats (integers included),
        # matching the output this function has always produced.
        if isinstance(value, (np.floating, np.integer)):
            return float(value)
        # np.bool_ and ndarray are rejected by json.dump, so convert them.
        if isinstance(value, np.bool_):
            return bool(value)
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    serializable = [
        {
            'config': r['config'],
            'meta_features': {k: _to_jsonable(v)
                              for k, v in r['meta_features'].items()},
            'metrics': {
                algo: {k: _to_jsonable(v) for k, v in m.items()}
                for algo, m in r['metrics'].items()
            },
        }
        for r in results
    ]

    # Create the directory here too so the function works when called alone.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    with open(os.path.join(RESULTS_DIR, 'full_results.json'), 'w', encoding='utf-8') as f:
        json.dump(serializable, f, indent=2)
217
+
218
+
219
if __name__ == '__main__':
    import sys

    # The first CLI argument selects a tier keyword; any other value is
    # treated as the name of a single network. Default is 'small'.
    tier = 'small' if len(sys.argv) <= 1 else sys.argv[1]

    tier_to_networks = {
        'small': SMALL_NETWORKS,
        'medium': MEDIUM_NETWORKS,
        'large': LARGE_NETWORKS,
        'all': ALL_NETWORKS,
    }
    networks = tier_to_networks.get(tier, [tier])

    logger.info(f"Running benchmark for tier: {tier} ({networks})")
    X, Y_shd, Y_nshd, configs, results = build_meta_dataset(networks=networks)

    # Console summary of the run
    rule = "=" * 80
    print("\n" + rule)
    print("BENCHMARK SUMMARY")
    print(rule)
    print(f"\nMeta-feature matrix: {X.shape}")
    print(f"SHD matrix: {Y_shd.shape}")

    print("\nMean SHD per algorithm:")
    mean_shd = Y_shd.mean().sort_values()
    print(mean_shd.to_string())

    print("\nBest algorithm per config:")
    best = Y_shd.idxmin(axis=1)
    win_counts = best.value_counts()
    print(win_counts.to_string())