Oguzz07 commited on
Commit
603c777
·
verified ·
1 Parent(s): 78e2a75

Add causal_selection/discovery/evaluator.py

Browse files
causal_selection/discovery/evaluator.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation module: compute SHD, F1, Precision, Recall between predicted and true CPDAGs.
3
+ """
4
+ import numpy as np
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ def compute_shd(pred_adj, true_adj):
11
+ """Compute Structural Hamming Distance between two CPDAGs/DAGs.
12
+
13
+ Both inputs are adjacency matrices where:
14
+ adj[i,j]=1 and adj[j,i]=0 means i->j (directed)
15
+ adj[i,j]=1 and adj[j,i]=1 means i--j (undirected)
16
+
17
+ SHD counts: missing edges + extra edges + wrongly oriented edges
18
+ """
19
+ n = pred_adj.shape[0]
20
+ assert pred_adj.shape == true_adj.shape, "Adjacency matrices must have same shape"
21
+
22
+ shd = 0
23
+ for i in range(n):
24
+ for j in range(i + 1, n):
25
+ # True edge state
26
+ t_ij, t_ji = true_adj[i, j], true_adj[j, i]
27
+ # Predicted edge state
28
+ p_ij, p_ji = pred_adj[i, j], pred_adj[j, i]
29
+
30
+ true_has_edge = (t_ij == 1 or t_ji == 1)
31
+ pred_has_edge = (p_ij == 1 or p_ji == 1)
32
+
33
+ if true_has_edge and not pred_has_edge:
34
+ # Missing edge
35
+ shd += 1
36
+ elif not true_has_edge and pred_has_edge:
37
+ # Extra edge
38
+ shd += 1
39
+ elif true_has_edge and pred_has_edge:
40
+ # Both have edge - check if same type
41
+ true_type = (t_ij, t_ji) # (1,0)=directed, (1,1)=undirected, (0,1)=reverse
42
+ pred_type = (p_ij, p_ji)
43
+ if true_type != pred_type:
44
+ # Wrong orientation
45
+ shd += 1
46
+
47
+ return shd
48
+
49
+
50
+ def compute_edge_metrics(pred_adj, true_adj):
51
+ """Compute precision, recall, F1 on edges (skeleton-level and directed).
52
+
53
+ Returns dict with:
54
+ - skeleton_precision, skeleton_recall, skeleton_f1: ignoring direction
55
+ - directed_precision, directed_recall, directed_f1: including direction
56
+ - shd: structural hamming distance
57
+ - n_true_edges, n_pred_edges: edge counts
58
+ """
59
+ n = pred_adj.shape[0]
60
+
61
+ # Skeleton comparison (ignoring direction)
62
+ true_skeleton = ((true_adj + true_adj.T) > 0).astype(int)
63
+ pred_skeleton = ((pred_adj + pred_adj.T) > 0).astype(int)
64
+
65
+ # Only upper triangle for skeleton (undirected)
66
+ skel_tp = skel_fp = skel_fn = 0
67
+ for i in range(n):
68
+ for j in range(i + 1, n):
69
+ t = true_skeleton[i, j]
70
+ p = pred_skeleton[i, j]
71
+ if t == 1 and p == 1:
72
+ skel_tp += 1
73
+ elif t == 0 and p == 1:
74
+ skel_fp += 1
75
+ elif t == 1 and p == 0:
76
+ skel_fn += 1
77
+
78
+ skel_precision = skel_tp / (skel_tp + skel_fp) if (skel_tp + skel_fp) > 0 else 0
79
+ skel_recall = skel_tp / (skel_tp + skel_fn) if (skel_tp + skel_fn) > 0 else 0
80
+ skel_f1 = (2 * skel_precision * skel_recall / (skel_precision + skel_recall)
81
+ if (skel_precision + skel_recall) > 0 else 0)
82
+
83
+ # Directed comparison (full adjacency)
84
+ dir_tp = dir_fp = dir_fn = 0
85
+ for i in range(n):
86
+ for j in range(n):
87
+ if i == j:
88
+ continue
89
+ t = true_adj[i, j]
90
+ p = pred_adj[i, j]
91
+ if t == 1 and p == 1:
92
+ dir_tp += 1
93
+ elif t == 0 and p == 1:
94
+ dir_fp += 1
95
+ elif t == 1 and p == 0:
96
+ dir_fn += 1
97
+
98
+ dir_precision = dir_tp / (dir_tp + dir_fp) if (dir_tp + dir_fp) > 0 else 0
99
+ dir_recall = dir_tp / (dir_tp + dir_fn) if (dir_tp + dir_fn) > 0 else 0
100
+ dir_f1 = (2 * dir_precision * dir_recall / (dir_precision + dir_recall)
101
+ if (dir_precision + dir_recall) > 0 else 0)
102
+
103
+ shd = compute_shd(pred_adj, true_adj)
104
+
105
+ # Count edges
106
+ n_true_edges = 0
107
+ n_pred_edges = 0
108
+ for i in range(n):
109
+ for j in range(i + 1, n):
110
+ if true_adj[i, j] or true_adj[j, i]:
111
+ n_true_edges += 1
112
+ if pred_adj[i, j] or pred_adj[j, i]:
113
+ n_pred_edges += 1
114
+
115
+ return {
116
+ 'shd': shd,
117
+ 'skeleton_precision': skel_precision,
118
+ 'skeleton_recall': skel_recall,
119
+ 'skeleton_f1': skel_f1,
120
+ 'directed_precision': dir_precision,
121
+ 'directed_recall': dir_recall,
122
+ 'directed_f1': dir_f1,
123
+ 'n_true_edges': n_true_edges,
124
+ 'n_pred_edges': n_pred_edges,
125
+ }
126
+
127
+
128
+ def dag_to_cpdag(dag_adjmat):
129
+ """Import from data.generator to avoid circular dependency."""
130
+ from causal_selection.data.generator import dag_to_cpdag as _dag_to_cpdag
131
+ return _dag_to_cpdag(dag_adjmat)
132
+
133
+
134
+ def evaluate_algorithm_result(result, true_cpdag):
135
+ """Evaluate a single algorithm result against ground truth CPDAG.
136
+
137
+ Args:
138
+ result: dict from run_algorithm (must have 'adjmat', 'output_type', 'status')
139
+ true_cpdag: ground truth CPDAG adjacency matrix
140
+
141
+ Returns:
142
+ dict with all metrics, or penalty metrics if algorithm failed
143
+ """
144
+ n = true_cpdag.shape[0]
145
+ max_possible_shd = n * (n - 1) // 2 # maximum possible SHD
146
+
147
+ if result['status'] != 'success' or result['adjmat'] is None:
148
+ return {
149
+ 'shd': max_possible_shd,
150
+ 'normalized_shd': 1.0,
151
+ 'skeleton_precision': 0.0,
152
+ 'skeleton_recall': 0.0,
153
+ 'skeleton_f1': 0.0,
154
+ 'directed_precision': 0.0,
155
+ 'directed_recall': 0.0,
156
+ 'directed_f1': 0.0,
157
+ 'n_true_edges': int(((true_cpdag + true_cpdag.T) > 0).sum() // 2),
158
+ 'n_pred_edges': 0,
159
+ 'runtime': result['runtime'],
160
+ 'status': result['status'],
161
+ }
162
+
163
+ pred_adj = result['adjmat']
164
+
165
+ # If the algorithm outputs a DAG, convert to CPDAG for fair comparison
166
+ if result['output_type'] == 'dag':
167
+ pred_cpdag = dag_to_cpdag(pred_adj)
168
+ else:
169
+ pred_cpdag = pred_adj # Already CPDAG or PAG-derived
170
+
171
+ # Compute metrics
172
+ metrics = compute_edge_metrics(pred_cpdag, true_cpdag)
173
+ metrics['normalized_shd'] = metrics['shd'] / max_possible_shd if max_possible_shd > 0 else 0
174
+ metrics['runtime'] = result['runtime']
175
+ metrics['status'] = result['status']
176
+
177
+ return metrics
178
+
179
+
180
+ if __name__ == '__main__':
181
+ # Test with Asia network
182
+ from causal_selection.data.generator import load_bn_model, get_true_dag_adjmat, dag_to_cpdag as gen_dag_to_cpdag, sample_dataset
183
+ from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
184
+ import warnings
185
+ warnings.filterwarnings('ignore')
186
+
187
+ model = load_bn_model('asia')
188
+ true_dag, nodes = get_true_dag_adjmat(model)
189
+ true_cpdag = gen_dag_to_cpdag(true_dag)
190
+ df = sample_dataset(model, 1000, seed=0)
191
+
192
+ print(f"ASIA (N=1000) - True edges: {int(((true_cpdag + true_cpdag.T) > 0).sum() // 2)}")
193
+ print(f"{'Algorithm':15s} {'SHD':>5s} {'nSHD':>6s} {'Skel_F1':>8s} {'Dir_F1':>7s} {'Runtime':>8s} {'Status'}")
194
+ print("-" * 70)
195
+
196
+ for algo_name in ALGORITHM_POOL:
197
+ result = run_algorithm(algo_name, df, timeout_sec=60)
198
+ metrics = evaluate_algorithm_result(result, true_cpdag)
199
+ print(f"{algo_name:15s} {metrics['shd']:5d} {metrics['normalized_shd']:6.3f} "
200
+ f"{metrics['skeleton_f1']:8.3f} {metrics['directed_f1']:7.3f} "
201
+ f"{metrics['runtime']:7.2f}s {metrics['status']}")