Add causal_selection/discovery/evaluator.py
Browse files
causal_selection/discovery/evaluator.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation module: compute SHD, F1, Precision, Recall between predicted and true CPDAGs.
|
| 3 |
+
"""
|
| 4 |
+
import numpy as np
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def compute_shd(pred_adj, true_adj):
|
| 11 |
+
"""Compute Structural Hamming Distance between two CPDAGs/DAGs.
|
| 12 |
+
|
| 13 |
+
Both inputs are adjacency matrices where:
|
| 14 |
+
adj[i,j]=1 and adj[j,i]=0 means i->j (directed)
|
| 15 |
+
adj[i,j]=1 and adj[j,i]=1 means i--j (undirected)
|
| 16 |
+
|
| 17 |
+
SHD counts: missing edges + extra edges + wrongly oriented edges
|
| 18 |
+
"""
|
| 19 |
+
n = pred_adj.shape[0]
|
| 20 |
+
assert pred_adj.shape == true_adj.shape, "Adjacency matrices must have same shape"
|
| 21 |
+
|
| 22 |
+
shd = 0
|
| 23 |
+
for i in range(n):
|
| 24 |
+
for j in range(i + 1, n):
|
| 25 |
+
# True edge state
|
| 26 |
+
t_ij, t_ji = true_adj[i, j], true_adj[j, i]
|
| 27 |
+
# Predicted edge state
|
| 28 |
+
p_ij, p_ji = pred_adj[i, j], pred_adj[j, i]
|
| 29 |
+
|
| 30 |
+
true_has_edge = (t_ij == 1 or t_ji == 1)
|
| 31 |
+
pred_has_edge = (p_ij == 1 or p_ji == 1)
|
| 32 |
+
|
| 33 |
+
if true_has_edge and not pred_has_edge:
|
| 34 |
+
# Missing edge
|
| 35 |
+
shd += 1
|
| 36 |
+
elif not true_has_edge and pred_has_edge:
|
| 37 |
+
# Extra edge
|
| 38 |
+
shd += 1
|
| 39 |
+
elif true_has_edge and pred_has_edge:
|
| 40 |
+
# Both have edge - check if same type
|
| 41 |
+
true_type = (t_ij, t_ji) # (1,0)=directed, (1,1)=undirected, (0,1)=reverse
|
| 42 |
+
pred_type = (p_ij, p_ji)
|
| 43 |
+
if true_type != pred_type:
|
| 44 |
+
# Wrong orientation
|
| 45 |
+
shd += 1
|
| 46 |
+
|
| 47 |
+
return shd
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def compute_edge_metrics(pred_adj, true_adj):
|
| 51 |
+
"""Compute precision, recall, F1 on edges (skeleton-level and directed).
|
| 52 |
+
|
| 53 |
+
Returns dict with:
|
| 54 |
+
- skeleton_precision, skeleton_recall, skeleton_f1: ignoring direction
|
| 55 |
+
- directed_precision, directed_recall, directed_f1: including direction
|
| 56 |
+
- shd: structural hamming distance
|
| 57 |
+
- n_true_edges, n_pred_edges: edge counts
|
| 58 |
+
"""
|
| 59 |
+
n = pred_adj.shape[0]
|
| 60 |
+
|
| 61 |
+
# Skeleton comparison (ignoring direction)
|
| 62 |
+
true_skeleton = ((true_adj + true_adj.T) > 0).astype(int)
|
| 63 |
+
pred_skeleton = ((pred_adj + pred_adj.T) > 0).astype(int)
|
| 64 |
+
|
| 65 |
+
# Only upper triangle for skeleton (undirected)
|
| 66 |
+
skel_tp = skel_fp = skel_fn = 0
|
| 67 |
+
for i in range(n):
|
| 68 |
+
for j in range(i + 1, n):
|
| 69 |
+
t = true_skeleton[i, j]
|
| 70 |
+
p = pred_skeleton[i, j]
|
| 71 |
+
if t == 1 and p == 1:
|
| 72 |
+
skel_tp += 1
|
| 73 |
+
elif t == 0 and p == 1:
|
| 74 |
+
skel_fp += 1
|
| 75 |
+
elif t == 1 and p == 0:
|
| 76 |
+
skel_fn += 1
|
| 77 |
+
|
| 78 |
+
skel_precision = skel_tp / (skel_tp + skel_fp) if (skel_tp + skel_fp) > 0 else 0
|
| 79 |
+
skel_recall = skel_tp / (skel_tp + skel_fn) if (skel_tp + skel_fn) > 0 else 0
|
| 80 |
+
skel_f1 = (2 * skel_precision * skel_recall / (skel_precision + skel_recall)
|
| 81 |
+
if (skel_precision + skel_recall) > 0 else 0)
|
| 82 |
+
|
| 83 |
+
# Directed comparison (full adjacency)
|
| 84 |
+
dir_tp = dir_fp = dir_fn = 0
|
| 85 |
+
for i in range(n):
|
| 86 |
+
for j in range(n):
|
| 87 |
+
if i == j:
|
| 88 |
+
continue
|
| 89 |
+
t = true_adj[i, j]
|
| 90 |
+
p = pred_adj[i, j]
|
| 91 |
+
if t == 1 and p == 1:
|
| 92 |
+
dir_tp += 1
|
| 93 |
+
elif t == 0 and p == 1:
|
| 94 |
+
dir_fp += 1
|
| 95 |
+
elif t == 1 and p == 0:
|
| 96 |
+
dir_fn += 1
|
| 97 |
+
|
| 98 |
+
dir_precision = dir_tp / (dir_tp + dir_fp) if (dir_tp + dir_fp) > 0 else 0
|
| 99 |
+
dir_recall = dir_tp / (dir_tp + dir_fn) if (dir_tp + dir_fn) > 0 else 0
|
| 100 |
+
dir_f1 = (2 * dir_precision * dir_recall / (dir_precision + dir_recall)
|
| 101 |
+
if (dir_precision + dir_recall) > 0 else 0)
|
| 102 |
+
|
| 103 |
+
shd = compute_shd(pred_adj, true_adj)
|
| 104 |
+
|
| 105 |
+
# Count edges
|
| 106 |
+
n_true_edges = 0
|
| 107 |
+
n_pred_edges = 0
|
| 108 |
+
for i in range(n):
|
| 109 |
+
for j in range(i + 1, n):
|
| 110 |
+
if true_adj[i, j] or true_adj[j, i]:
|
| 111 |
+
n_true_edges += 1
|
| 112 |
+
if pred_adj[i, j] or pred_adj[j, i]:
|
| 113 |
+
n_pred_edges += 1
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
'shd': shd,
|
| 117 |
+
'skeleton_precision': skel_precision,
|
| 118 |
+
'skeleton_recall': skel_recall,
|
| 119 |
+
'skeleton_f1': skel_f1,
|
| 120 |
+
'directed_precision': dir_precision,
|
| 121 |
+
'directed_recall': dir_recall,
|
| 122 |
+
'directed_f1': dir_f1,
|
| 123 |
+
'n_true_edges': n_true_edges,
|
| 124 |
+
'n_pred_edges': n_pred_edges,
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def dag_to_cpdag(dag_adjmat):
|
| 129 |
+
"""Import from data.generator to avoid circular dependency."""
|
| 130 |
+
from causal_selection.data.generator import dag_to_cpdag as _dag_to_cpdag
|
| 131 |
+
return _dag_to_cpdag(dag_adjmat)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def evaluate_algorithm_result(result, true_cpdag):
|
| 135 |
+
"""Evaluate a single algorithm result against ground truth CPDAG.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
result: dict from run_algorithm (must have 'adjmat', 'output_type', 'status')
|
| 139 |
+
true_cpdag: ground truth CPDAG adjacency matrix
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
dict with all metrics, or penalty metrics if algorithm failed
|
| 143 |
+
"""
|
| 144 |
+
n = true_cpdag.shape[0]
|
| 145 |
+
max_possible_shd = n * (n - 1) // 2 # maximum possible SHD
|
| 146 |
+
|
| 147 |
+
if result['status'] != 'success' or result['adjmat'] is None:
|
| 148 |
+
return {
|
| 149 |
+
'shd': max_possible_shd,
|
| 150 |
+
'normalized_shd': 1.0,
|
| 151 |
+
'skeleton_precision': 0.0,
|
| 152 |
+
'skeleton_recall': 0.0,
|
| 153 |
+
'skeleton_f1': 0.0,
|
| 154 |
+
'directed_precision': 0.0,
|
| 155 |
+
'directed_recall': 0.0,
|
| 156 |
+
'directed_f1': 0.0,
|
| 157 |
+
'n_true_edges': int(((true_cpdag + true_cpdag.T) > 0).sum() // 2),
|
| 158 |
+
'n_pred_edges': 0,
|
| 159 |
+
'runtime': result['runtime'],
|
| 160 |
+
'status': result['status'],
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
pred_adj = result['adjmat']
|
| 164 |
+
|
| 165 |
+
# If the algorithm outputs a DAG, convert to CPDAG for fair comparison
|
| 166 |
+
if result['output_type'] == 'dag':
|
| 167 |
+
pred_cpdag = dag_to_cpdag(pred_adj)
|
| 168 |
+
else:
|
| 169 |
+
pred_cpdag = pred_adj # Already CPDAG or PAG-derived
|
| 170 |
+
|
| 171 |
+
# Compute metrics
|
| 172 |
+
metrics = compute_edge_metrics(pred_cpdag, true_cpdag)
|
| 173 |
+
metrics['normalized_shd'] = metrics['shd'] / max_possible_shd if max_possible_shd > 0 else 0
|
| 174 |
+
metrics['runtime'] = result['runtime']
|
| 175 |
+
metrics['status'] = result['status']
|
| 176 |
+
|
| 177 |
+
return metrics
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
if __name__ == '__main__':
|
| 181 |
+
# Test with Asia network
|
| 182 |
+
from causal_selection.data.generator import load_bn_model, get_true_dag_adjmat, dag_to_cpdag as gen_dag_to_cpdag, sample_dataset
|
| 183 |
+
from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
|
| 184 |
+
import warnings
|
| 185 |
+
warnings.filterwarnings('ignore')
|
| 186 |
+
|
| 187 |
+
model = load_bn_model('asia')
|
| 188 |
+
true_dag, nodes = get_true_dag_adjmat(model)
|
| 189 |
+
true_cpdag = gen_dag_to_cpdag(true_dag)
|
| 190 |
+
df = sample_dataset(model, 1000, seed=0)
|
| 191 |
+
|
| 192 |
+
print(f"ASIA (N=1000) - True edges: {int(((true_cpdag + true_cpdag.T) > 0).sum() // 2)}")
|
| 193 |
+
print(f"{'Algorithm':15s} {'SHD':>5s} {'nSHD':>6s} {'Skel_F1':>8s} {'Dir_F1':>7s} {'Runtime':>8s} {'Status'}")
|
| 194 |
+
print("-" * 70)
|
| 195 |
+
|
| 196 |
+
for algo_name in ALGORITHM_POOL:
|
| 197 |
+
result = run_algorithm(algo_name, df, timeout_sec=60)
|
| 198 |
+
metrics = evaluate_algorithm_result(result, true_cpdag)
|
| 199 |
+
print(f"{algo_name:15s} {metrics['shd']:5d} {metrics['normalized_shd']:6.3f} "
|
| 200 |
+
f"{metrics['skeleton_f1']:8.3f} {metrics['directed_f1']:7.3f} "
|
| 201 |
+
f"{metrics['runtime']:7.2f}s {metrics['status']}")
|