Oguzz07 committed on
Commit 3b69a24 · verified · 1 Parent(s): 603c777

Add causal_selection/features/extractor.py

causal_selection/features/extractor.py ADDED
@@ -0,0 +1,346 @@
"""
Meta-feature extraction module for discrete observational datasets.

Extracts 34 features across 5 categories:
- Tier 1: Basic descriptors (6 features)
- Tier 2: Information-theoretic (8 features)
- Tier 3: Dependency structure (8 features)
- Tier 4: CI test landmark probes (6 features)
- Tier 5: Distribution shape (6 features)
"""
import numpy as np
import pandas as pd
from scipy.stats import entropy, chi2_contingency
from itertools import combinations
from collections import Counter
import warnings
import logging

warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)


def extract_all_features(df, n_probe_triplets=100, alpha=0.05):
    """Extract all meta-features from a discrete dataset.

    Args:
        df: pd.DataFrame with integer-encoded discrete columns
        n_probe_triplets: number of random (X,Y,Z) triplets for CI probes
        alpha: significance level for dependency tests

    Returns:
        dict of feature_name -> float
    """
    features = {}

    # Tier 1: Basic descriptors
    features.update(_basic_features(df))

    # Tier 2: Information-theoretic
    features.update(_info_theory_features(df))

    # Tier 3: Dependency structure
    features.update(_dependency_features(df, alpha=alpha))

    # Tier 4: CI test landmark probes
    features.update(_ci_probe_features(df, n_probes=n_probe_triplets, alpha=alpha))

    # Tier 5: Distribution shape
    features.update(_distribution_features(df))

    return features


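A minimal usage sketch for the entry point above, assuming only the integer-encoded discrete DataFrame the docstring requires; the data and column names here are synthetic and purely illustrative:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    toy = pd.DataFrame(rng.integers(0, 3, size=(500, 6)),
                       columns=[f'X{i}' for i in range(6)])
    feats = extract_all_features(toy, n_probe_triplets=50)
    print(len(feats))  # 34 scalar meta-features, one per FEATURE_NAMES entry
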
def _basic_features(df):
    """Tier 1: Basic dataset descriptors."""
    n_samples, n_vars = df.shape
    cardinalities = df.nunique().values

    return {
        'n_samples': n_samples,
        'n_variables': n_vars,
        'n_over_p': n_samples / max(n_vars, 1),
        'avg_cardinality': cardinalities.mean(),
        'max_cardinality': cardinalities.max(),
        'min_cardinality': cardinalities.min(),
    }


def _info_theory_features(df):
    """Tier 2: Information-theoretic features."""
    n, p = df.shape

    # Per-variable entropy
    entropies = []
    for col in df.columns:
        vc = df[col].value_counts(normalize=True)
        entropies.append(entropy(vc))
    entropies = np.array(entropies)

    # Pairwise mutual information (subsample if too many pairs)
    cols = list(range(p))
    all_pairs = list(combinations(cols, 2))

    # Limit pairs for large datasets
    max_pairs = min(len(all_pairs), 500)
    if len(all_pairs) > max_pairs:
        rng = np.random.RandomState(42)
        pair_indices = rng.choice(len(all_pairs), max_pairs, replace=False)
        pairs = [all_pairs[i] for i in pair_indices]
    else:
        pairs = all_pairs

    mis = []
    norm_mis = []
    for i, j in pairs:
        mi = _mutual_information(df.iloc[:, i].values, df.iloc[:, j].values)
        mis.append(mi)

        # Normalized MI
        denom = np.sqrt(entropies[i] * entropies[j])
        if denom > 1e-10:
            norm_mis.append(mi / denom)
        else:
            norm_mis.append(0.0)

    mis = np.array(mis)
    norm_mis = np.array(norm_mis)

    return {
        'mean_entropy': entropies.mean(),
        'std_entropy': entropies.std(),
        'max_entropy': entropies.max(),
        'mean_pairwise_MI': mis.mean(),
        'std_pairwise_MI': mis.std(),
        'max_pairwise_MI': mis.max(),
        'mean_normalized_MI': norm_mis.mean(),
        'frac_high_MI_pairs': (mis > 0.05).mean(),  # threshold for "meaningful" MI
    }


def _dependency_features(df, alpha=0.05):
    """Tier 3: Dependency structure features via chi-squared tests."""
    n, p = df.shape
    cols = list(range(p))
    all_pairs = list(combinations(cols, 2))

    # Limit pairs
    max_pairs = min(len(all_pairs), 500)
    if len(all_pairs) > max_pairs:
        rng = np.random.RandomState(42)
        pair_indices = rng.choice(len(all_pairs), max_pairs, replace=False)
        pairs = [all_pairs[i] for i in pair_indices]
    else:
        pairs = all_pairs

    chi2_stats = []
    pvals = []
    cramers_vs = []

    for i, j in pairs:
        try:
            ct = pd.crosstab(df.iloc[:, i], df.iloc[:, j])
            if ct.shape[0] < 2 or ct.shape[1] < 2:
                continue
            chi2, pval, dof, expected = chi2_contingency(ct)
            chi2_stats.append(chi2)
            pvals.append(pval)

            # Cramér's V
            min_dim = min(ct.shape[0], ct.shape[1]) - 1
            if min_dim > 0 and n > 0:
                v = np.sqrt(chi2 / (n * min_dim))
                cramers_vs.append(v)
        except Exception:
            continue

    chi2_stats = np.array(chi2_stats) if chi2_stats else np.array([0.0])
    pvals = np.array(pvals) if pvals else np.array([1.0])
    cramers_vs = np.array(cramers_vs) if cramers_vs else np.array([0.0])

    return {
        'density_proxy': (pvals < alpha).mean(),
        'mean_chi2_stat': chi2_stats.mean(),
        'std_chi2_stat': chi2_stats.std(),
        'max_chi2_stat': chi2_stats.max(),
        'mean_cramers_v': cramers_vs.mean(),
        'max_cramers_v': cramers_vs.max(),
        'frac_weak_deps': (cramers_vs < 0.1).mean(),
        'frac_strong_deps': (cramers_vs > 0.3).mean(),
    }


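A quick sanity check of the Cramér's V computation used above (illustrative only, not part of the module): two identical binary variables should score near the maximum of 1.

    import numpy as np
    import pandas as pd
    from scipy.stats import chi2_contingency

    x = np.array([0, 1] * 100)
    y = x.copy()  # perfectly dependent pair
    ct = pd.crosstab(x, y)
    chi2, _, _, _ = chi2_contingency(ct)
    v = np.sqrt(chi2 / (len(x) * (min(ct.shape) - 1)))
    print(round(v, 2))  # 0.99: Yates' correction on 2x2 tables keeps it just under 1
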
def _ci_probe_features(df, n_probes=100, alpha=0.05):
    """Tier 4: Conditional independence test landmark probes.

    Sample random (X, Y, Z) triplets:
    - Test X ⊥ Y (marginal)
    - Test X ⊥ Y | Z (conditional)
    Summarize test statistics.
    """
    n, p = df.shape

    if p < 3:
        return {
            'mean_pval_marginal': 0.5,
            'frac_dep_marginal': 0.5,
            'mean_pval_conditional': 0.5,
            'frac_dep_conditional': 0.5,
            'v_structure_proxy': 0.0,
            'faithfulness_proxy': 0.0,
        }

    rng = np.random.RandomState(42)
    n_probes = min(n_probes, p * (p - 1) * (p - 2) // 6)  # cap at the number of distinct triplets

    pvals_marginal = []
    pvals_conditional = []

    for _ in range(n_probes):
        try:
            idxs = rng.choice(p, size=3, replace=False)
            i, j, k = idxs

            # Marginal test: X ⊥ Y
            ct = pd.crosstab(df.iloc[:, i], df.iloc[:, j])
            if ct.shape[0] >= 2 and ct.shape[1] >= 2:
                _, pval, _, _ = chi2_contingency(ct)
                pvals_marginal.append(pval)

            # Conditional test: X ⊥ Y | Z
            # Stratify by Z values
            z_vals = df.iloc[:, k].unique()
            cond_pvals = []
            for z_val in z_vals:
                mask = df.iloc[:, k] == z_val
                if mask.sum() < 5:
                    continue
                ct_cond = pd.crosstab(df.iloc[:, i][mask], df.iloc[:, j][mask])
                if ct_cond.shape[0] >= 2 and ct_cond.shape[1] >= 2:
                    try:
                        _, pval_c, _, _ = chi2_contingency(ct_cond)
                        cond_pvals.append(pval_c)
                    except Exception:
                        pass

            if cond_pvals:
                # Aggregate strata with the mean p-value (Fisher's method is a stricter alternative)
                pvals_conditional.append(np.mean(cond_pvals))
        except Exception:
            continue

    pvals_marginal = np.array(pvals_marginal) if pvals_marginal else np.array([0.5])
    pvals_conditional = np.array(pvals_conditional) if pvals_conditional else np.array([0.5])

    frac_dep_m = (pvals_marginal < alpha).mean()
    frac_dep_c = (pvals_conditional < alpha).mean()

    return {
        'mean_pval_marginal': pvals_marginal.mean(),
        'frac_dep_marginal': frac_dep_m,
        'mean_pval_conditional': pvals_conditional.mean(),
        'frac_dep_conditional': frac_dep_c,
        # Positive when conditioning removes dependence (chains/forks);
        # negative when it induces dependence (colliders/v-structures)
        'v_structure_proxy': frac_dep_m - frac_dep_c,
        'faithfulness_proxy': abs(frac_dep_m - frac_dep_c),  # divergence between marginal and conditional behavior
    }


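To see why the sign of v_structure_proxy carries information, consider a hand-built collider, Z = X XOR Y with independent fair X and Y: the marginal test should not reject, while every Z-stratum shows perfect dependence. This sketch is illustrative and not part of the module:

    import numpy as np
    import pandas as pd
    from scipy.stats import chi2_contingency

    rng = np.random.default_rng(0)
    x = rng.integers(0, 2, 2000)
    y = rng.integers(0, 2, 2000)
    z = x ^ y  # collider: determined by both parents
    _, p_marg, _, _ = chi2_contingency(pd.crosstab(x, y))
    _, p_cond, _, _ = chi2_contingency(pd.crosstab(x[z == 0], y[z == 0]))
    print(p_marg > 0.05, p_cond < 0.05)  # typically True True: conditioning induces dependence
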
def _distribution_features(df):
    """Tier 5: Distribution shape features."""
    n, p = df.shape

    mode_freqs = []
    balance_scores = []
    cardinalities = []

    for col in df.columns:
        vc = df[col].value_counts(normalize=True)
        mode_freqs.append(vc.iloc[0])  # frequency of most common value

        card = len(vc)
        cardinalities.append(card)

        # Balance: entropy / log(cardinality); 1.0 = perfectly uniform
        if card > 1:
            h = entropy(vc)
            max_h = np.log(card)
            balance_scores.append(h / max_h if max_h > 0 else 0)
        else:
            balance_scores.append(0.0)

    mode_freqs = np.array(mode_freqs)
    balance_scores = np.array(balance_scores)
    cardinalities = np.array(cardinalities)

    return {
        'mean_mode_frequency': mode_freqs.mean(),
        'std_mode_frequency': mode_freqs.std(),
        'mean_balance': balance_scores.mean(),
        'uniformity_score': balance_scores.mean(),  # alias
        'frac_binary_vars': (cardinalities == 2).mean(),
        'frac_high_card_vars': (cardinalities > 5).mean(),
    }


def _mutual_information(x, y):
    """Compute mutual information (in nats) between two discrete arrays."""
    # Joint and marginal distributions via plug-in counting
    n = len(x)
    joint = Counter(zip(x, y))
    marginal_x = Counter(x)
    marginal_y = Counter(y)

    mi = 0.0
    for (xi, yi), count in joint.items():
        p_xy = count / n
        p_x = marginal_x[xi] / n
        p_y = marginal_y[yi] / n
        if p_xy > 0 and p_x > 0 and p_y > 0:
            mi += p_xy * np.log(p_xy / (p_x * p_y))

    return max(mi, 0.0)


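A worked check of the estimator above, run in this module's namespace: for two identical fair binary variables, MI equals the entropy H(X) = ln 2 ≈ 0.6931 nats, and the plug-in estimate recovers it exactly.

    import numpy as np

    x = np.array([0, 1] * 500)
    print(round(_mutual_information(x, x), 4))  # 0.6931, i.e. ln(2)
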
# Feature names for consistent ordering
FEATURE_NAMES = [
    # Tier 1: Basic
    'n_samples', 'n_variables', 'n_over_p', 'avg_cardinality', 'max_cardinality', 'min_cardinality',
    # Tier 2: Info-theoretic
    'mean_entropy', 'std_entropy', 'max_entropy', 'mean_pairwise_MI', 'std_pairwise_MI',
    'max_pairwise_MI', 'mean_normalized_MI', 'frac_high_MI_pairs',
    # Tier 3: Dependency
    'density_proxy', 'mean_chi2_stat', 'std_chi2_stat', 'max_chi2_stat',
    'mean_cramers_v', 'max_cramers_v', 'frac_weak_deps', 'frac_strong_deps',
    # Tier 4: CI probes
    'mean_pval_marginal', 'frac_dep_marginal', 'mean_pval_conditional',
    'frac_dep_conditional', 'v_structure_proxy', 'faithfulness_proxy',
    # Tier 5: Distribution
    'mean_mode_frequency', 'std_mode_frequency', 'mean_balance', 'uniformity_score',
    'frac_binary_vars', 'frac_high_card_vars',
]


def features_to_vector(features_dict):
    """Convert feature dict to ordered numpy vector."""
    return np.array([features_dict.get(name, 0.0) for name in FEATURE_NAMES])


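A sketch of the intended downstream flow, pairing the two public helpers; the meta-learner itself is out of scope here, so the vector is only printed (synthetic data, illustrative only):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(1)
    demo = pd.DataFrame(rng.integers(0, 2, size=(300, 5)))
    vec = features_to_vector(extract_all_features(demo))
    print(vec.shape)  # (34,): one row of a meta-feature design matrix
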
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    from causal_selection.data.generator import load_bn_model, sample_dataset

    model = load_bn_model('asia')
    df = sample_dataset(model, 1000, seed=0)

    print("Extracting features from ASIA (N=1000)...")
    features = extract_all_features(df)

    for name in FEATURE_NAMES:
        val = features.get(name, 'MISSING')
        if isinstance(val, float):
            print(f"  {name:30s}: {val:10.4f}")
        else:
            print(f"  {name:30s}: {val}")

    print(f"\nTotal features: {len(FEATURE_NAMES)}")