Update README with final results (71.3% hit rate, 178 configs, pairwise ranking)
README.md
CHANGED
|
@@ -6,33 +6,50 @@ A meta-learning system that predicts the **top-3 best causal discovery algorithm
|
|
| 6 |
|
| 7 |
Given a new discrete dataset (pandas DataFrame), the system:
|
| 8 |
1. **Extracts 34 meta-features** (entropy, mutual information, chi² statistics, CI test probes, etc.)
|
| 9 |
-
2. **Predicts normalized SHD** for each of 9 algorithms via
|
| 10 |
3. **Ranks and returns the top-3** algorithms expected to produce the most accurate CPDAG
|
| 11 |
|
| 12 |
## 📊 Performance (Leave-One-Network-Out Cross-Validation)
|
| 13 |
|
| 14 |
| Metric | Value |
|
| 15 |
|--------|-------|
|
| 16 |
-
| **Top-3 Hit Rate** | **
|
| 17 |
-
| **
|
| 18 |
-
| **Mean Regret** | **0.012** (tiny SHD gap vs oracle selection) |
|
| 19 |
| **Median Regret** | **0.000** (majority of predictions are perfect) |
|
| 20 |
|
| 21 |
-
|
| 22 |
|
| 23 |
## 🧪 Algorithm Pool (9 algorithms)
|
| 24 |
|
| 25 |
-
| Algorithm | Family | Library | Output |
|
| 26 |
-
|-----------|--------|---------|--------|
|
| 27 |
-
| **
|
| 28 |
-
| **
|
| 29 |
-
| **
|
| 30 |
-
| **
|
| 31 |
-
| **
|
| 32 |
-
| **
|
| 33 |
-
| **
|
| 34 |
-
| **
|
| 35 |
-
| **
|
| 36 |
|
| 37 |
## 🔬 Key Insight: Dependency Parsing Connection
|
| 38 |
|
|
@@ -43,11 +60,16 @@ This project was inspired by a structural parallel between **NLP dependency pars
|
|
| 43 |
|
| 44 |
The biaffine pairwise scoring mechanism from Dozat & Manning (2017) was independently reinvented by AVICI and CauScale for causal structure learning — validating this connection.
|
| 45 |
|
| 46 |
-
|
| 47 |
-
1. `
|
| 48 |
-
2. `
|
| 49 |
-
3. `
|
| 50 |
-
4. `
|
|
| 51 |
|
| 52 |
## 🚀 Quick Start
|
| 53 |
|
|
@@ -80,17 +102,24 @@ causal_selection/
|
|
| 80 |
│ ├── trainer.py # Multi-Output RF/GBM + LONO-CV evaluation
|
| 81 |
│ └── predictor.py # Inference: dataset → top-3 prediction
|
| 82 |
├── models/
|
| 83 |
-
│ ├── meta_learner.pkl # Trained
|
| 84 |
│ └── scaler.pkl # Feature scaler
|
| 85 |
├── benchmark.py # Full benchmark orchestration
|
| 86 |
-
|
| 87 |
```
|
| 88 |
|
| 89 |
## 📈 Benchmark Data
|
| 90 |
|
| 91 |
- **14 bnlearn networks**: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
|
| 92 |
-
- **
|
| 93 |
-
- **1,
|
|
| 94 |
|
| 95 |
## 🔧 Dependencies
|
| 96 |
|
|
@@ -108,10 +137,18 @@ joblib
|
|
| 108 |
|
| 109 |
- **Causal-Copilot** (arxiv:2504.13263) — Closest existing algorithm selection system
|
| 110 |
- **AVICI** (arxiv:2205.12934) — Amortized causal structure learning (biaffine architecture)
|
| 111 |
- **Dozat & Manning** (arxiv:1611.01734) — Deep Biaffine Attention for dependency parsing
|
| 112 |
- **SATzilla** (arxiv:1401.2474) — Algorithm selection via meta-learning
|
| 113 |
- **bnlearn** (bnlearn.com) — Bayesian network benchmark repository
|
| 114 |
|
| 115 |
## License
|
| 116 |
|
| 117 |
MIT
|
|
|
|

Given a new discrete dataset (pandas DataFrame), the system:
1. **Extracts 34 meta-features** (entropy, mutual information, chi² statistics, CI test probes, etc.)
2. **Predicts normalized SHD** for each of 9 algorithms via trained models
3. **Ranks and returns the top-3** algorithms expected to produce the most accurate CPDAG
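
The final ranking step can be sketched as follows. This is a minimal illustration, not the project's `predictor.py`: the function name `top_k_algorithms` and the predicted values are made up, and the only thing taken from the README is that a lower predicted normalized SHD is better.

```python
from typing import Dict, List


def top_k_algorithms(predicted_shd: Dict[str, float], k: int = 3) -> List[str]:
    """Return the k algorithm names with the lowest predicted normalized SHD."""
    # Lower SHD means the learned graph is closer to the true one.
    # Ties are broken alphabetically so the result is deterministic.
    ranked = sorted(predicted_shd.items(), key=lambda item: (item[1], item[0]))
    return [name for name, _ in ranked[:k]]


# Illustrative (made-up) predictions for the 9 algorithms in the pool.
example_predictions = {
    "GES": 0.12, "PC": 0.15, "FCI": 0.30, "K2": 0.22, "HC": 0.25,
    "Tabu": 0.26, "GRaSP": 0.18, "BOSS": 0.19, "MMHC": 0.28,
}
print(top_k_algorithms(example_predictions))
```

In a real run the dictionary would come from the trained models applied to the extracted meta-features.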

## 📊 Performance (Leave-One-Network-Out Cross-Validation)

### Best Model: Pairwise-GBM Ranking

| Metric | Value |
|--------|-------|
| **Top-3 Hit Rate** | **71.3%** (true best algorithm is in predicted top-3) |
| **Mean Regret** | **0.011** (normalized SHD gap vs. oracle selection) |
| **Median Regret** | **0.000** (at least half of the predictions have zero regret) |
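
For readers who want to reproduce these metrics, here is a hedged sketch of how they might be computed per dataset config. It assumes that regret is measured for the best algorithm among the predicted top-3; the README does not say whether regret uses the top-1 pick or the best of the top-3, so `k` is left as a parameter. The function names and the toy SHD values are illustrative.

```python
from statistics import mean, median
from typing import Dict, Sequence


def top_k_hit(true_shd: Dict[str, float], ranked: Sequence[str], k: int = 3) -> bool:
    """True if an algorithm with the lowest true SHD is among the first k picks."""
    best = min(true_shd.values())
    return any(true_shd[name] == best for name in ranked[:k])


def top_k_regret(true_shd: Dict[str, float], ranked: Sequence[str], k: int = 3) -> float:
    """SHD gap between the best of the first k picks and the oracle's choice."""
    best = min(true_shd.values())
    return min(true_shd[name] for name in ranked[:k]) - best


# Two toy configs: (true normalized SHD per algorithm, predicted ranking).
configs = [
    ({"GES": 0.10, "PC": 0.20, "FCI": 0.30, "K2": 0.40}, ["PC", "GES", "K2", "FCI"]),
    ({"GES": 0.50, "PC": 0.20, "FCI": 0.30, "K2": 0.10}, ["GES", "PC", "FCI", "K2"]),
]
hits = [top_k_hit(shd, ranking) for shd, ranking in configs]
regrets = [top_k_regret(shd, ranking) for shd, ranking in configs]

hit_rate = sum(hits) / len(hits)
print(f"hit rate={hit_rate:.3f} mean regret={mean(regrets):.3f} "
      f"median regret={median(regrets):.3f}")
```

Setting `k=1` gives the stricter top-1 variant of both metrics.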

### Model Comparison (178 configs, 14 networks + augmented)

| Model | Top-3 Hit Rate | NDCG@3 | Mean Regret |
|-------|---------------|--------|-------------|
| **Pairwise-GBM** | **71.3%** | n/a | 0.011 |
| GBM-300-lr01 | 67.4% | 0.957 | 0.011 |
| RF-200 | 66.9% | 0.961 | 0.007 |
| RF-500 | 66.3% | 0.962 | 0.007 |
| GBM-500-lr05 | 65.2% | 0.948 | 0.013 |
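
All models above are evaluated with leave-one-network-out cross-validation. A minimal sketch of the split logic is shown below, assuming every dataset config is labelled with the network it was generated from and that augmented configs keep their parent network's label (an assumption; the README does not spell this out). The helper name `leave_one_network_out` is illustrative; scikit-learn's `LeaveOneGroupOut` provides the same kind of grouped split.

```python
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence, Tuple


def leave_one_network_out(
    networks: Sequence[str],
) -> Iterator[Tuple[str, List[int], List[int]]]:
    """Yield (held_out_network, train_indices, test_indices) per network."""
    by_network: Dict[str, List[int]] = defaultdict(list)
    for idx, name in enumerate(networks):
        by_network[name].append(idx)
    for held_out in sorted(by_network):
        # Every config from the held-out network goes to the test fold,
        # so the model never sees that network during training.
        test_idx = by_network[held_out]
        train_idx = [i for i, name in enumerate(networks) if name != held_out]
        yield held_out, train_idx, test_idx


# Toy example: six configs drawn from three networks.
config_networks = ["asia", "asia", "sachs", "alarm", "alarm", "alarm"]
folds = list(leave_one_network_out(config_networks))
for held_out, train_idx, test_idx in folds:
    print(held_out, train_idx, test_idx)
```

Grouping by network rather than by config avoids leakage between configs that share the same underlying graph.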

### Progression

| Stage | Configs | Networks | Top-3 Hit Rate |
|-------|---------|----------|---------------|
| Initial (small nets) | 65 | 4 | 68.2% |
| All 14 networks | 122 | 14 | 70.5% |
| + Data augmentation | 178 | 14+aug | **71.3%** |

## 🧪 Algorithm Pool (9 algorithms)

| Algorithm | Family | Library | Output | Wins |
|-----------|--------|---------|--------|------|
| **GES** | Score-based | causal-learn | CPDAG | 47% |
| **PC** | Constraint-based | causal-learn | CPDAG | 32% |
| **FCI** | Constraint-based | causal-learn | PAG | 8% |
| **K2** | Score-based | pgmpy | DAG | 6% |
| **HC** | Score-based (greedy) | pgmpy | DAG | 3% |
| **Tabu** | Score-based (meta) | pgmpy | DAG | 2% |
| **GRaSP** | Permutation-based | causal-learn | CPDAG | 1% |
| **BOSS** | Permutation-based | causal-learn | CPDAG | 1% |
| **MMHC** | Hybrid | pgmpy | DAG | <1% |

## 🔬 Key Insight: Dependency Parsing Connection

|
|
|

The biaffine pairwise scoring mechanism from Dozat & Manning (2017) was independently reinvented by AVICI and CauScale for causal structure learning, validating this connection.

### Top Predictive Meta-Features
1. `n_variables` (30%): network size (how many nodes in the graph)
2. `max_pairwise_MI` (24%): strongest pairwise dependency (≈ biaffine arc score)
3. `max_cramers_v` (8%): strongest association strength
4. `max_entropy` (7%): variable complexity
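
To make these four features concrete, here is a dependency-free sketch of how they could be computed from discrete columns. The feature names come from the list above, but the exact definitions (natural-log entropy and mutual information, Cramér's V without bias correction) are assumptions, and the project's extractor may differ. Columns are passed as a plain dict of lists; a pandas DataFrame can be converted with `df.to_dict("list")`.

```python
import math
from collections import Counter
from itertools import combinations
from typing import Dict, List, Sequence


def entropy(values: Sequence) -> float:
    """Shannon entropy (in nats) of a discrete sample."""
    n = len(values)
    return -sum((c / n) * math.log(c / n) for c in Counter(values).values())


def mutual_information(x: Sequence, y: Sequence) -> float:
    """Empirical mutual information (in nats) between two discrete samples."""
    n = len(x)
    px, py, pxy = Counter(x), Counter(y), Counter(zip(x, y))
    return sum((c / n) * math.log(c * n / (px[a] * py[b]))
               for (a, b), c in pxy.items())


def cramers_v(x: Sequence, y: Sequence) -> float:
    """Cramér's V from the chi-squared statistic (no bias correction)."""
    n = len(x)
    px, py, pxy = Counter(x), Counter(y), Counter(zip(x, y))
    k = min(len(px), len(py)) - 1
    if k == 0:  # a constant column carries no association
        return 0.0
    chi2 = sum((pxy.get((a, b), 0) - px[a] * py[b] / n) ** 2 / (px[a] * py[b] / n)
               for a in px for b in py)
    return math.sqrt(chi2 / (n * k))


def meta_features(columns: Dict[str, List]) -> Dict[str, float]:
    """Compute the four features listed above; needs at least two columns."""
    pairs = list(combinations(columns.values(), 2))
    return {
        "n_variables": len(columns),
        "max_pairwise_MI": max(mutual_information(x, y) for x, y in pairs),
        "max_cramers_v": max(cramers_v(x, y) for x, y in pairs),
        "max_entropy": max(entropy(col) for col in columns.values()),
    }


data = {
    "A": [0, 0, 1, 1],
    "B": [0, 0, 1, 1],  # identical to A, so MI(A, B) equals H(A)
    "C": [0, 1, 0, 1],  # independent of A and B in this sample
}
features = meta_features(data)
print(features)
```

The all-pairs loop is quadratic in the number of variables, which is fine for networks of the size listed in the benchmark section.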

### Three Ideas Borrowed from Parsing
1. **Biaffine-style pairwise features**: MI and Cramér's V between all variable pairs play the role of parsing's arc scores
2. **Pairwise ranking** (our best model): For each algorithm pair (A, B), predict which one wins, then count wins to rank the algorithms. Inspired by pairwise, tournament-style parser selection
3. **Cross-domain transfer**: Train on well-characterized bnlearn networks, then predict on new unseen datasets (analogous to cross-lingual parser transfer)
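
The win-counting step of the pairwise ranking idea can be sketched as below. The trained pairwise GBM is replaced here by a hypothetical callable `prefers(a, b)`, which stands in for "the model predicts that `a` beats `b` on this dataset"; the function name and the toy strengths are illustrative only.

```python
from itertools import combinations
from typing import Callable, Dict, List, Sequence


def rank_by_pairwise_wins(
    algorithms: Sequence[str],
    prefers: Callable[[str, str], bool],
) -> List[str]:
    """Rank algorithms by the number of head-to-head comparisons they win."""
    wins: Dict[str, int] = {name: 0 for name in algorithms}
    for a, b in combinations(algorithms, 2):
        winner = a if prefers(a, b) else b
        wins[winner] += 1
    # Most wins first; sorted() is stable, so ties keep their input order.
    return sorted(algorithms, key=lambda name: -wins[name])


# Stand-in for the trained pairwise model: a fixed, made-up strength score.
toy_strength = {"GES": 3, "PC": 2, "FCI": 1, "K2": 0}


def toy_prefers(a: str, b: str) -> bool:
    return toy_strength[a] > toy_strength[b]


ranking = rank_by_pairwise_wins(["K2", "PC", "GES", "FCI"], toy_prefers)
print(ranking[:3])
```

With 9 algorithms this requires 36 pairwise predictions per dataset, and the top-3 recommendation is simply the first three entries of the ranking.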

## 🚀 Quick Start

|
|
|
│   ├── trainer.py          # Multi-Output RF/GBM + LONO-CV evaluation
│   └── predictor.py        # Inference: dataset → top-3 prediction
├── models/
│   ├── meta_learner.pkl    # Trained GBM (multi-output fallback)
│   ├── pairwise_model.pkl  # Pairwise ranking GBM (best model)
│   └── scaler.pkl          # Feature scaler
├── benchmark.py            # Full benchmark orchestration
├── run_benchmark.py        # Resumable benchmark runner
└── augment_and_improve.py  # Data augmentation + model improvement
```

## 📈 Benchmark Data

- **14 bnlearn networks**: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
- **178 dataset configs**: 122 original + 56 augmented (variable subsampling, sample-size variation, noise injection)
- **1,600+ algorithm runs**: 9 algorithms × 178 configs with per-algorithm timeout

### Data Augmentation Strategies
- **Variable subsampling**: Drop 20-40% of variables to create virtual sub-networks
- **Sample-size variation**: Generate datasets with N = 300, 750, 1500, and 3000 samples for each network
- **Noise injection**: Randomly flip 5-10% of categorical values
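
Two of these strategies, variable subsampling and noise injection, can be sketched without any dependencies. This is an illustration under stated assumptions rather than the code in `augment_and_improve.py`: the function names are made up, noise is applied per column, and a "flip" is taken to mean replacing a value with a different category observed in the same column. Sample-size variation is not shown because it requires sampling from the network itself.

```python
import random
from typing import Dict, List


def subsample_variables(
    columns: Dict[str, List], drop_fraction: float, rng: random.Random
) -> Dict[str, List]:
    """Drop a fraction of the variables to form a smaller virtual dataset."""
    names = list(columns)
    n_keep = min(len(names), max(2, round(len(names) * (1 - drop_fraction))))
    kept = set(rng.sample(names, n_keep))
    return {name: columns[name] for name in names if name in kept}


def inject_noise(
    columns: Dict[str, List], flip_fraction: float, rng: random.Random
) -> Dict[str, List]:
    """Replace a fraction of each column's cells with a different category."""
    noisy = {}
    for name, values in columns.items():
        categories = sorted(set(values))
        new_values = list(values)  # copy, so the input is left untouched
        n_flip = round(len(values) * flip_fraction) if len(categories) > 1 else 0
        for i in rng.sample(range(len(values)), n_flip):
            new_values[i] = rng.choice([c for c in categories if c != values[i]])
        noisy[name] = new_values
    return noisy


rng = random.Random(0)
# Toy dataset: 10 ternary variables with 20 samples each.
columns = {f"X{i}": [(i + j) % 3 for j in range(20)] for i in range(10)}
sub = subsample_variables(columns, drop_fraction=0.3, rng=rng)
noisy = inject_noise(sub, flip_fraction=0.1, rng=rng)
print(len(sub), "variables kept")
```

When variables are dropped, the ground-truth graph used for scoring has to be restricted to the remaining variables as well; that step is outside the scope of this sketch.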

## 🔧 Dependencies

|
|
|

- **Causal-Copilot** (arxiv:2504.13263): Closest existing algorithm selection system
- **AVICI** (arxiv:2205.12934): Amortized causal structure learning (biaffine architecture)
- **CauScale** (arxiv:2602.08629): Scalable neural causal discovery
- **Dozat & Manning** (arxiv:1611.01734): Deep Biaffine Attention for dependency parsing
- **TreeCRF** (arxiv:2005.00975): Global structural training loss for parsing
- **SATzilla** (arxiv:1401.2474): Algorithm selection via meta-learning
- **bnlearn** (bnlearn.com): Bayesian network benchmark repository

## 🔮 Future Work (Phase 2)
1. **Biaffine neural encoder**: Pre-train a neural feature extractor that learns variable-pair "arc scores"
2. **Portfolio regret loss** (TreeCRF-inspired): Global ranking optimization instead of per-algorithm MSE
3. **Hyperparameter co-selection**: Predict not just which algorithm but optimal hyperparameters (CASH)
4. **Ensemble prediction**: Run top-3 and vote on edges across their CPDAGs
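
As a rough illustration of the ensemble idea, the sketch below takes a majority vote over adjacencies. It is speculative, since this item is future work: edge orientation is ignored (only the skeleton is voted on), each graph is given as a list of node pairs, and the function name and example edges are made up.

```python
from collections import Counter
from typing import FrozenSet, Iterable, Set, Tuple

Edge = FrozenSet[str]


def vote_on_skeleton(
    graphs: Iterable[Iterable[Tuple[str, str]]], min_votes: int = 2
) -> Set[Edge]:
    """Keep an adjacency if at least min_votes graphs contain it."""
    votes: Counter = Counter()
    for edges in graphs:
        # One set per graph, so an edge listed twice (or in both
        # directions) counts as a single vote from that graph.
        votes.update({frozenset(edge) for edge in edges})
    return {edge for edge, count in votes.items() if count >= min_votes}


# Toy outputs from three algorithms on variables A-D.
ges_edges = [("A", "B"), ("B", "C")]
pc_edges = [("B", "A"), ("C", "D")]
grasp_edges = [("A", "B"), ("B", "C"), ("A", "D")]

consensus = vote_on_skeleton([ges_edges, pc_edges, grasp_edges])
print(sorted(tuple(sorted(edge)) for edge in consensus))
```

Voting on orientations as well would need a rule for combining directed and undirected edges, which this sketch deliberately leaves open.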

## License

MIT