CausalBox / utils /casual_algorithms.py
ShutterStack's picture
major changes
ab66d4e verified
# utils/causal_algorithms.py
import networkx as nx
import pandas as pd
import numpy as np
from causallearn.search.ConstraintBased.PC import pc
# from causallearn.search.ScoreBased.GES import ges # Example import for GES
# from notears import notears_linear # Example import for NOTEARS
class CausalDiscoveryAlgorithms:
def pc_algorithm(self, df, alpha=0.05):
"""
Run PC algorithm to learn causal graph.
Returns a directed graph's adjacency matrix.
Requires numerical data.
"""
data_array = df.to_numpy()
cg = pc(data_array, alpha=alpha, indep_test="fisherz")
adj_matrix = cg.G.graph
return adj_matrix
def ges_algorithm(self, df):
"""
Placeholder for GES (Greedy Equivalence Search) algorithm.
Returns a directed graph's adjacency matrix.
You would implement or integrate the GES algorithm here.
"""
# Example: G, edges = ges(data_array)
# For now, returning a simplified correlation-based graph for demonstration
print("GES algorithm is a placeholder. Using a simplified correlation-based graph.")
G = nx.DiGraph()
nodes = df.columns
G.add_nodes_from(nodes)
corr_matrix = df.corr().abs()
threshold = 0.3
for i, col1 in enumerate(nodes):
for col2 in nodes[i+1:]:
if corr_matrix.loc[col1, col2] > threshold:
if np.random.rand() > 0.5:
G.add_edge(col1, col2)
else:
G.add_edge(col2, col1)
return nx.to_numpy_array(G) # Convert to adjacency matrix
def notears_algorithm(self, df):
"""
Placeholder for NOTEARS algorithm.
Returns a directed graph's adjacency matrix.
You would implement or integrate the NOTEARS algorithm here.
"""
# Example: W_est = notears_linear(data_array)
print("NOTEARS algorithm is a placeholder. Using a simplified correlation-based graph.")
G = nx.DiGraph()
nodes = df.columns
G.add_nodes_from(nodes)
corr_matrix = df.corr().abs()
threshold = 0.3
for i, col1 in enumerate(nodes):
for col2 in nodes[i+1:]:
if corr_matrix.loc[col1, col2] > threshold:
if np.random.rand() > 0.5:
G.add_edge(col1, col2)
else:
G.add_edge(col2, col1)
return nx.to_numpy_array(G) # Convert to adjacency matrix