Oguzz07 commited on
Commit
fbf08d1
·
verified ·
1 Parent(s): a47b09f

Add run_benchmark.py

Browse files
Files changed (1) hide show
  1. run_benchmark.py +212 -0
run_benchmark.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Resume benchmark from partial results, then run medium and large networks too.
3
+ Optimized for CPU: reduced timeouts, skip heavy combos.
4
+ """
5
+ import os
6
+ import sys
7
+ import time
8
+ import numpy as np
9
+ import pandas as pd
10
+ import json
11
+ import logging
12
+ import warnings
13
+
14
+ warnings.filterwarnings('ignore')
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Suppress the verbose BOSS/GRaSP output
19
+ logging.getLogger('causallearn').setLevel(logging.WARNING)
20
+
21
+ sys.path.insert(0, '/app')
22
+ from causal_selection.data.generator import (
23
+ load_bn_model, get_true_dag_adjmat, dag_to_cpdag, sample_dataset,
24
+ SMALL_NETWORKS, MEDIUM_NETWORKS, LARGE_NETWORKS, get_network_tier
25
+ )
26
+ from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
27
+ from causal_selection.discovery.evaluator import evaluate_algorithm_result
28
+ from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES
29
+
30
# All algorithm names in the pool, in a fixed order used for result columns.
ALGO_NAMES = list(ALGORITHM_POOL.keys())
# Directory where all benchmark CSV outputs (and '_partial' checkpoints) go.
RESULTS_DIR = '/app/causal_selection/data/results'

# More aggressive sample sizes - fewer but covering range
# (larger network tiers get fewer sample sizes to bound total runtime).
SAMPLE_SIZES_FAST = {
    'small': [500, 1000, 2000, 5000],
    'medium': [500, 1000, 2000],
    'large': [500, 1000],
}

# Per-algorithm timeout (seconds) - algorithm-specific!
# Keyed by algorithm name, then by network tier ('small'/'medium'/'large').
ALGO_TIMEOUT = {
    'PC_discrete': {'small': 30, 'medium': 120, 'large': 300},
    'FCI': {'small': 30, 'medium': 120, 'large': 300},
    'GES': {'small': 30, 'medium': 120, 'large': 300},
    'BOSS': {'small': 30, 'medium': 120, 'large': 300},
    'GRaSP': {'small': 30, 'medium': 120, 'large': 300},
    'HC': {'small': 30, 'medium': 60, 'large': 120},
    'Tabu': {'small': 30, 'medium': 60, 'large': 120},
    'MMHC': {'small': 30, 'medium': 60, 'large': 120},
    'K2': {'small': 20, 'medium': 30, 'large': 60},
}

# Number of random seeds per (network, sample size) combination.
SEEDS = 2  # Reduced from 3 to speed up
54
+
55
+
56
def load_existing_results():
    """Return the set of (network, n_samples, seed) keys already benchmarked.

    Scans both the periodic checkpoint (``configs_partial.csv``) and the
    final output (``configs.csv``) so an interrupted run can resume without
    re-running completed configurations.

    Returns:
        set[tuple[str, int, int]]: keys of completed configurations.
    """
    existing = set()
    candidate_paths = [
        os.path.join(RESULTS_DIR, 'configs_partial.csv'),
        os.path.join(RESULTS_DIR, 'configs.csv'),
    ]

    for path in candidate_paths:
        if not os.path.exists(path):
            continue
        try:
            df = pd.read_csv(path)
        except pd.errors.EmptyDataError:
            # A checkpoint interrupted mid-write can leave a zero-byte CSV;
            # treat it as "nothing done" instead of crashing the resume.
            continue
        for _, row in df.iterrows():
            key = (row['network'], int(row['n_samples']), int(row['seed']))
            existing.add(key)

    return existing
70
+
71
+
72
def _load_first_existing(prefixes):
    """Return the rows of the first existing ``RESULTS_DIR/<prefix>.csv``.

    The checkpoint name should be listed before the final name so a resumed
    run prefers the most recent data.  Returns a list of row dicts, or []
    when none of the files exist.
    """
    for prefix in prefixes:
        path = os.path.join(RESULTS_DIR, f'{prefix}.csv')
        if os.path.exists(path):
            return pd.read_csv(path).to_dict('records')
    return []


def run_benchmark():
    """Run the full benchmark with resume capability.

    Skips every (network, n_samples, seed) combination already recorded in
    the saved configs, runs each discovery algorithm with a tier-specific
    timeout on the remaining ones, checkpoints every 3 configs, and prints
    a summary of mean SHD and per-algorithm win counts at the end.
    """
    existing = load_existing_results()
    logger.info(f"Found {len(existing)} existing configs")

    # Reload any previously checkpointed rows so new results append to them.
    all_features = _load_first_existing(['meta_features_partial', 'meta_features'])
    all_shd = _load_first_existing(['shd_matrix_partial', 'shd_matrix'])
    # BUGFIX: _save_results writes 'normalized_shd_matrix_partial.csv'; the
    # old prefix 'normalized_shd_partial' never matched, so checkpointed
    # normalized-SHD rows were silently dropped (and the four tables
    # desynchronized) on resume.
    all_nshd = _load_first_existing(['normalized_shd_matrix_partial', 'normalized_shd_matrix'])
    all_configs = _load_first_existing(['configs_partial', 'configs'])

    logger.info(f"Starting with {len(all_configs)} existing results")

    # Enumerate every (network, n_samples, seed) combination not yet done.
    all_networks = SMALL_NETWORKS + MEDIUM_NETWORKS + LARGE_NETWORKS
    configs_to_run = []

    for net in all_networks:
        tier = get_network_tier(net)
        for n_samples in SAMPLE_SIZES_FAST[tier]:
            for seed in range(SEEDS):
                if (net, n_samples, seed) not in existing:
                    configs_to_run.append((net, n_samples, seed, tier))

    logger.info(f"Configs to run: {len(configs_to_run)}")

    total = len(configs_to_run)
    for idx, (network, n_samples, seed, tier) in enumerate(configs_to_run):
        logger.info(f"\n[{idx+1}/{total}] {network} N={n_samples} seed={seed}")

        try:
            # Ground truth: DAG adjacency and its CPDAG (evaluation target).
            model = load_bn_model(network)
            true_dag, node_names = get_true_dag_adjmat(model)
            true_cpdag = dag_to_cpdag(true_dag)

            # Simulated dataset for this (network, size, seed) combination.
            df = sample_dataset(model, n_samples, seed=seed)

            # Meta-features describing the dataset (input for algo selection).
            features = extract_all_features(df, n_probe_triplets=80)

            # Run every algorithm with its tier-specific timeout.
            algo_metrics = {}
            for algo_name in ALGO_NAMES:
                timeout = ALGO_TIMEOUT[algo_name][tier]
                result = run_algorithm(algo_name, df, timeout_sec=timeout)
                metrics = evaluate_algorithm_result(result, true_cpdag)
                algo_metrics[algo_name] = metrics

                s = metrics['status']
                if s == 'success':
                    logger.info(f" {algo_name:12s}: SHD={metrics['shd']:3d} F1={metrics['skeleton_f1']:.3f} t={metrics['runtime']:.1f}s")
                else:
                    logger.info(f" {algo_name:12s}: {s} t={metrics['runtime']:.1f}s")

            # Store results; missing features default to 0.0 so every row has
            # the same columns.
            feat_row = {name: features.get(name, 0.0) for name in FEATURE_NAMES}
            all_features.append(feat_row)

            all_shd.append({algo: algo_metrics[algo]['shd'] for algo in ALGO_NAMES})
            all_nshd.append({algo: algo_metrics[algo]['normalized_shd'] for algo in ALGO_NAMES})

            all_configs.append({
                'network': network,
                'n_samples': n_samples,
                'seed': seed,
                'n_variables': len(node_names),
                # Undirected edge count: symmetrize the CPDAG, halve the sum.
                'n_true_edges': int(((true_cpdag + true_cpdag.T) > 0).sum() // 2),
            })

            # Checkpoint every 3 configs so an interrupted run can resume.
            if (idx + 1) % 3 == 0:
                _save_results(all_features, all_shd, all_nshd, all_configs, partial=True)

        except Exception:
            # One bad config must not kill the whole benchmark;
            # logger.exception records the full traceback.
            logger.exception(f"FAILED {network} N={n_samples} seed={seed}")

    # Final (non-partial) save.
    _save_results(all_features, all_shd, all_nshd, all_configs, partial=False)

    # Summary: mean SHD per algorithm and how often each algorithm wins.
    Y_shd = pd.DataFrame(all_shd)
    configs_df = pd.DataFrame(all_configs)

    print("\n" + "=" * 80)
    print("BENCHMARK COMPLETE")
    print("=" * 80)
    print(f"Total configs: {len(all_configs)}")
    print(f"Networks: {configs_df['network'].unique()}")
    print(f"\nMean SHD per algorithm:")
    print(Y_shd.mean().sort_values())
    print(f"\nBest algorithm count:")
    print(Y_shd.idxmin(axis=1).value_counts())
200
+
201
+
202
def _save_results(features, shds, nshds, configs, partial=True):
    """Persist the four benchmark result tables to RESULTS_DIR as CSVs.

    When *partial* is True the filenames carry a '_partial' suffix so
    periodic checkpoints never clobber the final outputs.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    suffix = '_partial' if partial else ''
    tables = {
        'meta_features': features,
        'shd_matrix': shds,
        'normalized_shd_matrix': nshds,
        'configs': configs,
    }
    for stem, rows in tables.items():
        out_path = os.path.join(RESULTS_DIR, f'{stem}{suffix}.csv')
        pd.DataFrame(rows).to_csv(out_path, index=False)
209
+
210
+
211
# Entry point: run (or resume) the full benchmark when executed as a script.
if __name__ == '__main__':
    run_benchmark()