Oguzz07 committed
Commit 845e234 · verified · parent: 7bf8978

Update with augmented data (178 configs) + pairwise ranking model (71.3% hit rate)

Files changed (1):
  augment_and_improve.py  +517 -0
augment_and_improve.py (ADDED, 517 lines)
"""
Comprehensive data augmentation and model improvement pipeline.

Augmentation strategies:
1. Variable subsampling: randomly drop variables to create new graph topologies
2. Sample-size variation: subsample rows from existing large-N datasets
3. Noise injection: randomly flip a fraction of categorical cells

Then trains multiple model architectures and runs a full comparison.
"""
import os
import sys
import numpy as np
import pandas as pd
import json
import logging
import warnings
import time
from itertools import combinations

warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
logging.getLogger('causallearn').setLevel(logging.ERROR)

sys.path.insert(0, '/app')
from causal_selection.data.generator import (
    load_bn_model, get_true_dag_adjmat, dag_to_cpdag, sample_dataset,
    ALL_NETWORKS, MEDIUM_NETWORKS, LARGE_NETWORKS, get_network_tier
)
from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
from causal_selection.discovery.evaluator import evaluate_algorithm_result
from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES
from causal_selection.meta_learner.trainer import (
    load_meta_dataset, evaluate_lono_cv, train_meta_learner,
    save_model, get_feature_importance, ALGO_NAMES, RESULTS_DIR
)

from sklearn.ensemble import (
    RandomForestRegressor, GradientBoostingRegressor,
    RandomForestClassifier, GradientBoostingClassifier
)
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import joblib
# ==============================================================
# AUGMENTATION
# ==============================================================

def augment_all(networks_for_varsub=None, n_varsub=3, drop_frac=0.3,
                networks_for_samplesub=None, n_samplesub=2):
    """Run all augmentation strategies and return combined augmented data."""

    all_feats, all_shds, all_nshds, all_cfgs = [], [], [], []

    # Strategy 1: Variable subsampling
    logger.info("="*60)
    logger.info("AUGMENTATION: Variable Subsampling")
    logger.info("="*60)

    if networks_for_varsub is None:
        # Only networks with at least 8 variables
        networks_for_varsub = ['sachs', 'alarm', 'child', 'insurance',
                               'water', 'barley', 'mildew',
                               'hailfinder', 'hepar2']

    for net_name in networks_for_varsub:
        try:
            model = load_bn_model(net_name)
            true_dag, node_names = get_true_dag_adjmat(model)
            n_vars = len(node_names)

            if n_vars < 8:
                continue

            tier = get_network_tier(net_name)
            timeout = {'small': 60, 'medium': 90, 'large': 120}[tier]

            for aug_i in range(n_varsub):
                # hash() is salted per process, so use a stable per-network
                # offset to keep the seeds reproducible across runs
                net_offset = sum(ord(ch) for ch in net_name) % 100
                rng = np.random.RandomState(200 + aug_i * 100 + net_offset)

                # Keep roughly 60-80% of variables (clamped to [0.5, 0.85])
                keep_frac = 1.0 - drop_frac + rng.uniform(-0.1, 0.1)
                keep_frac = max(0.5, min(0.85, keep_frac))
                n_keep = max(5, int(n_vars * keep_frac))
                keep_idx = sorted(rng.choice(n_vars, n_keep, replace=False))

                # NOTE: the induced subgraph is treated as ground truth;
                # confounding through dropped variables is ignored
                sub_dag = true_dag[np.ix_(keep_idx, keep_idx)]
                sub_cpdag = dag_to_cpdag(sub_dag)
                sub_names = [node_names[i] for i in keep_idx]

                n_samples = int(rng.choice([500, 1000, 2000]))
                df_full = sample_dataset(model, n_samples, seed=200 + aug_i)
                df_sub = df_full[sub_names].copy()
                df_sub.columns = [f'X{i}' for i in range(len(sub_names))]

                logger.info(f"  VarSub {net_name} #{aug_i}: {n_vars}->{n_keep} vars, N={n_samples}")

                f, s, ns, c = _run_single(df_sub, sub_cpdag,
                                          f'{net_name}_vs{aug_i}', n_samples,
                                          200 + aug_i, n_keep, timeout)
                if f is not None:
                    all_feats.append(f)
                    all_shds.append(s)
                    all_nshds.append(ns)
                    all_cfgs.append(c)

        except Exception as e:
            logger.error(f"VarSub failed for {net_name}: {e}")

    # Strategy 2: Sample-size subsampling from existing large-N datasets
    logger.info("\n" + "="*60)
    logger.info("AUGMENTATION: Sample Size Variation")
    logger.info("="*60)

    if networks_for_samplesub is None:
        networks_for_samplesub = ['asia', 'cancer', 'earthquake', 'sachs',
                                  'survey', 'alarm', 'child']

    # NOTE: all four sizes are run per network; n_samplesub is currently unused
    sub_sample_sizes = [300, 750, 1500, 3000]

    for net_name in networks_for_samplesub:
        try:
            model = load_bn_model(net_name)
            true_dag, node_names = get_true_dag_adjmat(model)
            true_cpdag = dag_to_cpdag(true_dag)
            n_vars = len(node_names)
            tier = get_network_tier(net_name)
            timeout = {'small': 60, 'medium': 90, 'large': 120}[tier]

            for ss_i, n_samples in enumerate(sub_sample_sizes):
                seed = 300 + ss_i
                df = sample_dataset(model, n_samples, seed=seed)

                logger.info(f"  SampleSub {net_name} N={n_samples} seed={seed}")

                f, s, ns, c = _run_single(df, true_cpdag,
                                          f'{net_name}_ss{ss_i}', n_samples,
                                          seed, n_vars, timeout)
                if f is not None:
                    all_feats.append(f)
                    all_shds.append(s)
                    all_nshds.append(ns)
                    all_cfgs.append(c)

        except Exception as e:
            logger.error(f"SampleSub failed for {net_name}: {e}")

    # Strategy 3: Noise injection on small networks
    logger.info("\n" + "="*60)
    logger.info("AUGMENTATION: Noise Injection")
    logger.info("="*60)

    noise_networks = ['asia', 'sachs', 'survey', 'cancer', 'earthquake']

    for net_name in noise_networks:
        try:
            model = load_bn_model(net_name)
            true_dag, node_names = get_true_dag_adjmat(model)
            true_cpdag = dag_to_cpdag(true_dag)
            n_vars = len(node_names)
            timeout = 60

            for noise_i, noise_frac in enumerate([0.05, 0.10]):
                seed = 400 + noise_i
                n_samples = 1000
                df = sample_dataset(model, n_samples, seed=seed)

                # Inject random category flips, drawing uniformly from each
                # column's original category range (maxima precomputed so
                # earlier flips cannot shrink the range)
                rng = np.random.RandomState(seed)
                col_max = df.max().values
                n_flip = int(n_samples * n_vars * noise_frac)
                for _ in range(n_flip):
                    r = rng.randint(n_samples)
                    c = rng.randint(n_vars)
                    df.iloc[r, c] = rng.randint(0, col_max[c] + 1)

                logger.info(f"  Noise {net_name} frac={noise_frac}")

                f, s, ns, c = _run_single(df, true_cpdag,
                                          f'{net_name}_n{noise_i}', n_samples,
                                          seed, n_vars, timeout)
                if f is not None:
                    all_feats.append(f)
                    all_shds.append(s)
                    all_nshds.append(ns)
                    all_cfgs.append(c)

        except Exception as e:
            logger.error(f"Noise failed for {net_name}: {e}")

    return all_feats, all_shds, all_nshds, all_cfgs

def _run_single(df, true_cpdag, net_label, n_samples, seed, n_vars, timeout):
    """Run feature extraction + all algorithms on one config."""
    try:
        features = extract_all_features(df, n_probe_triplets=60)

        shd_row = {}
        nshd_row = {}

        for algo_name in ALGO_NAMES:
            result = run_algorithm(algo_name, df, timeout_sec=timeout)
            metrics = evaluate_algorithm_result(result, true_cpdag)
            shd_row[algo_name] = metrics['shd']
            nshd_row[algo_name] = metrics['normalized_shd']

        feat_row = {name: features.get(name, 0.0) for name in FEATURE_NAMES}
        config = {
            'network': net_label,
            'n_samples': n_samples,
            'seed': seed,
            'n_variables': n_vars,
            'n_true_edges': int(((true_cpdag + true_cpdag.T) > 0).sum() // 2),
        }

        # Log best algo
        best = min(shd_row, key=shd_row.get)
        logger.info(f"  Best: {best} SHD={shd_row[best]}")

        return feat_row, shd_row, nshd_row, config

    except Exception as e:
        logger.error(f"  Failed: {e}")
        return None, None, None, None

# ==============================================================
# PAIRWISE RANKING MODEL
# ==============================================================

def train_pairwise_ranking(X, Y_nshd, configs):
    """Train pairwise ranking classifiers: for each (algo_i, algo_j) pair,
    train a classifier to predict whether algo_i beats algo_j.

    At inference: count wins for each algorithm, rank by win count.
    """
    n_algos = len(ALGO_NAMES)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    pair_models = {}
    pair_accuracies = {}

    for i in range(n_algos):
        for j in range(i + 1, n_algos):
            # Label: 1 if algo_i has lower nSHD (better) than algo_j
            y = (Y_nshd.iloc[:, i] < Y_nshd.iloc[:, j]).astype(int).values

            # Skip degenerate pairs where one algorithm always wins;
            # such pairs cast no vote at inference
            if y.mean() == 0 or y.mean() == 1:
                pair_models[(i, j)] = None
                pair_accuracies[(i, j)] = y.mean()
                continue

            clf = GradientBoostingClassifier(
                n_estimators=200, max_depth=3, learning_rate=0.05,
                random_state=42
            )
            clf.fit(X_scaled, y)

            train_acc = clf.score(X_scaled, y)  # training accuracy (optimistic)
            pair_models[(i, j)] = clf
            pair_accuracies[(i, j)] = train_acc

    return pair_models, scaler, pair_accuracies


def predict_pairwise_ranking(pair_models, scaler, X_new, k=3):
    """Use pairwise models to rank algorithms via win-count."""
    X_scaled = scaler.transform(X_new)
    n_algos = len(ALGO_NAMES)
    n_samples = X_scaled.shape[0]

    results = []
    for idx in range(n_samples):
        wins = np.zeros(n_algos)
        x = X_scaled[idx:idx + 1]

        for i in range(n_algos):
            for j in range(i + 1, n_algos):
                model = pair_models.get((i, j))
                if model is None:
                    continue
                pred = model.predict(x)[0]
                if pred == 1:  # algo_i wins
                    wins[i] += 1
                else:
                    wins[j] += 1

        ranking = np.argsort(-wins)  # most wins first
        results.append(ranking[:k])

    return np.array(results)

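# Worked example of the win-count vote (illustrative only, not executed):
# with 4 algorithms there are C(4,2) = 6 pairwise models. If for one dataset
# the predictions are (0,1)->0, (0,2)->1, (0,3)->1, (1,2)->1, (1,3)->1,
# (2,3)->0 (where 1 means algo_i beats algo_j), the win counts come out as
# [2, 3, 0, 1], so the top-2 ranking is [1, 0]. Ties are broken arbitrarily
# by np.argsort.
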
def evaluate_pairwise_lono(X, Y_nshd, configs, k=3):
    """Leave-one-network-out CV for the pairwise ranking model."""
    # Augmented configs are named '<base>_<suffix>', so group folds by the
    # base network name to avoid leakage between a network and its variants
    base_nets = configs['network'].apply(lambda x: x.split('_')[0]).values
    unique_base = sorted(set(base_nets))

    top_k_hits = 0
    regrets = []
    total = 0

    for test_base in unique_base:
        test_mask = base_nets == test_base
        train_mask = ~test_mask

        if train_mask.sum() < 5 or test_mask.sum() == 0:
            continue

        X_train = X.values[train_mask]
        Y_train = Y_nshd[train_mask]
        X_test = X.values[test_mask]
        Y_test = Y_nshd.values[test_mask]

        # Train pairwise models
        scaler = StandardScaler()
        X_train_s = scaler.fit_transform(X_train)

        n_algos = len(ALGO_NAMES)
        pair_models = {}

        for i in range(n_algos):
            for j in range(i + 1, n_algos):
                y = (Y_train.iloc[:, i] < Y_train.iloc[:, j]).astype(int).values
                if y.mean() == 0 or y.mean() == 1:
                    pair_models[(i, j)] = None
                    continue
                clf = GradientBoostingClassifier(
                    n_estimators=100, max_depth=3, learning_rate=0.05,
                    random_state=42
                )
                clf.fit(X_train_s, y)
                pair_models[(i, j)] = clf

        # Predict
        X_test_s = scaler.transform(X_test)

        for idx in range(len(X_test_s)):
            wins = np.zeros(n_algos)
            x = X_test_s[idx:idx + 1]

            for i in range(n_algos):
                for j in range(i + 1, n_algos):
                    m = pair_models.get((i, j))
                    if m is None:
                        continue
                    if m.predict(x)[0] == 1:
                        wins[i] += 1
                    else:
                        wins[j] += 1

            pred_top_k = np.argsort(-wins)[:k]
            true_best = np.argmin(Y_test[idx])

            if true_best in pred_top_k:
                top_k_hits += 1

            # Regret: nSHD gap between the best algorithm among the
            # predicted top-k and the oracle best
            oracle = Y_test[idx, true_best]
            selected = min(Y_test[idx, a] for a in pred_top_k)
            regrets.append(selected - oracle)
            total += 1

    hit_rate = top_k_hits / total if total > 0 else 0
    mean_regret = np.mean(regrets) if regrets else 0

    return {
        'top_k_hit_rate': hit_rate,
        'mean_regret': mean_regret,
        'median_regret': np.median(regrets) if regrets else 0,
        'n_evaluated': total,
    }

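# Worked example of the regret metric (illustrative): if a test row has
# nSHDs [0.30, 0.10, 0.25, 0.40], the oracle best is algorithm 1 (0.10).
# If the predicted top-3 is [0, 2, 3], the hit is missed, the selected
# score is min(0.30, 0.25, 0.40) = 0.25, and the regret is 0.15.
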
# ==============================================================
# MAIN
# ==============================================================

if __name__ == '__main__':
    start_time = time.time()

    # Step 1: Augment
    print("="*80)
    print("STEP 1: DATA AUGMENTATION")
    print("="*80)

    feats, shds, nshds, cfgs = augment_all(
        n_varsub=2, drop_frac=0.3,
        n_samplesub=2,
    )

    print(f"\nGenerated {len(cfgs)} augmented configs")

    # Merge with original
    X_orig, Y_shd_orig, Y_nshd_orig, configs_orig = load_meta_dataset()

    X_aug = pd.DataFrame(feats, columns=FEATURE_NAMES)
    Y_shd_aug = pd.DataFrame(shds, columns=ALGO_NAMES)
    Y_nshd_aug = pd.DataFrame(nshds, columns=ALGO_NAMES)
    configs_aug = pd.DataFrame(cfgs)

    X_all = pd.concat([X_orig, X_aug], ignore_index=True)
    Y_shd_all = pd.concat([Y_shd_orig, Y_shd_aug], ignore_index=True)
    Y_nshd_all = pd.concat([Y_nshd_orig, Y_nshd_aug], ignore_index=True)
    configs_all = pd.concat([configs_orig, configs_aug], ignore_index=True)

    print(f"Total dataset: {len(configs_all)} configs "
          f"({len(configs_orig)} original + {len(configs_aug)} augmented)")

    # Save augmented data
    X_all.to_csv(os.path.join(RESULTS_DIR, 'meta_features.csv'), index=False)
    Y_shd_all.to_csv(os.path.join(RESULTS_DIR, 'shd_matrix.csv'), index=False)
    Y_nshd_all.to_csv(os.path.join(RESULTS_DIR, 'normalized_shd_matrix.csv'), index=False)
    configs_all.to_csv(os.path.join(RESULTS_DIR, 'configs.csv'), index=False)

    # Step 2: Model comparison
    print("\n" + "="*80)
    print("STEP 2: MODEL COMPARISON (LONO-CV)")
    print("="*80)

    # Reload the merged dataset from disk
    X, Y_shd, Y_nshd, configs = load_meta_dataset()

    print(f"\n{'Model':25s} {'Top3Hit':>8s} {'NDCG@3':>8s} {'Regret':>8s}")
    print("-"*55)

    model_configs = [
        ('RF-200', 'rf', {'n_estimators': 200}),
        ('RF-500', 'rf', {'n_estimators': 500}),
        ('GBM-500-lr05', 'gbm', {'n_estimators': 500, 'max_depth': 3, 'learning_rate': 0.05}),
        ('GBM-300-lr01', 'gbm', {'n_estimators': 300, 'max_depth': 4, 'learning_rate': 0.01}),
        ('GBM-200-lr1', 'gbm', {'n_estimators': 200, 'max_depth': 5, 'learning_rate': 0.1}),
    ]

    best_hit = -1.0  # start below any possible hit rate so best_config is always set
    best_config = None

    for name, mtype, kwargs in model_configs:
        r = evaluate_lono_cv(X, Y_nshd, configs, model_type=mtype, k=3, **kwargs)
        o = r['overall']
        print(f"{name:25s} {o['top_k_hit_rate']:8.3f} {o['ndcg_at_k']:8.3f} {o['mean_regret']:8.4f}")
        if o['top_k_hit_rate'] > best_hit:
            best_hit = o['top_k_hit_rate']
            best_config = (name, mtype, kwargs, o)

    # Remember the best multi-output config before the pairwise comparison,
    # so the fallback branch below does not need to re-run LONO-CV
    best_mo_cfg = best_config

    # Pairwise ranking
    print(f"\n{'Pairwise-GBM':25s}", end="")
    pw_results = evaluate_pairwise_lono(X, Y_nshd, configs, k=3)
    print(f" {pw_results['top_k_hit_rate']:8.3f} {'N/A':>8s} {pw_results['mean_regret']:8.4f}")

    if pw_results['top_k_hit_rate'] > best_hit:
        best_hit = pw_results['top_k_hit_rate']
        best_config = ('Pairwise-GBM', 'pairwise', {}, pw_results)

    print(f"\n{'='*55}")
    print(f"BEST MODEL: {best_config[0]} (hit rate={best_hit:.3f})")
    print(f"{'='*55}")

    # Train & save best multi-output model
    if best_config[1] != 'pairwise':
        model, scaler = train_meta_learner(X, Y_nshd,
                                           model_type=best_config[1],
                                           **best_config[2])
        save_model(model, scaler)

        avg_imp, _ = get_feature_importance(model)
        print("\nTop 10 Features:")
        for feat, imp in sorted(avg_imp.items(), key=lambda x: -x[1])[:10]:
            print(f"  {feat:30s}: {imp:.4f}")
    else:
        # Save pairwise model separately
        print("Pairwise model is best - training final version...")
        pair_models, scaler, _ = train_pairwise_ranking(X, Y_nshd, configs)
        os.makedirs('/app/causal_selection/models', exist_ok=True)
        joblib.dump({'pair_models': pair_models, 'scaler': scaler},
                    '/app/causal_selection/models/pairwise_model.pkl')
        # Also train and save the best multi-output model as a fallback
        model, scaler = train_meta_learner(X, Y_nshd, model_type=best_mo_cfg[1], **best_mo_cfg[2])
        save_model(model, scaler)

    elapsed = time.time() - start_time
    print(f"\nTotal time: {elapsed/60:.1f} minutes")

    # Save summary
    summary = {
        'n_configs_original': int(len(configs_orig)),
        'n_configs_augmented': int(len(configs_aug)),
        'n_configs_total': int(len(configs_all)),
        'best_model': best_config[0],
        'best_top3_hit_rate': float(best_hit),
        'best_metrics': {k: float(v) if isinstance(v, (float, np.floating)) else v
                         for k, v in best_config[3].items()},
    }
    with open(os.path.join(RESULTS_DIR, 'improvement_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"\nSummary saved to {RESULTS_DIR}/improvement_summary.json")
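
Usage note (a minimal sketch, not part of the commit): ranking algorithms for a new dataset with the saved pairwise model. It assumes this file is importable as augment_and_improve and that new_df is a pandas DataFrame of discrete observations; the model path and the feature-extraction call match the ones used above.

import joblib
import pandas as pd
from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES
from causal_selection.meta_learner.trainer import ALGO_NAMES
from augment_and_improve import predict_pairwise_ranking

# Load the bundle written by the main script
bundle = joblib.load('/app/causal_selection/models/pairwise_model.pkl')

# Extract meta-features for the new dataset, aligned to FEATURE_NAMES
feats = extract_all_features(new_df, n_probe_triplets=60)
X_new = pd.DataFrame([{name: feats.get(name, 0.0) for name in FEATURE_NAMES}])

# Top-3 algorithms by pairwise win count
top3 = predict_pairwise_ranking(bundle['pair_models'], bundle['scaler'], X_new, k=3)
print([ALGO_NAMES[i] for i in top3[0]])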