major changes

Files changed:
- .env +1 -0
- .gitattributes +35 -35
- data/sample_dataset.csv +0 -0
- main.py +45 -37
- requirements.txt +14 -11
- routers/__pycache__/chatbot_routes.cpython-310.pyc +0 -0
- routers/__pycache__/discover_routes.cpython-310.pyc +0 -0
- routers/__pycache__/intervene_routes.cpython-310.pyc +0 -0
- routers/__pycache__/prediction_routes.cpython-310.pyc +0 -0
- routers/__pycache__/preprocess_routes.cpython-310.pyc +0 -0
- routers/__pycache__/timeseries_routes.cpython-310.pyc +0 -0
- routers/__pycache__/treatment_routes.cpython-310.pyc +0 -0
- routers/__pycache__/visualize_routes.cpython-310.pyc +0 -0
- routers/chatbot_routes.py +25 -0
- routers/discover_routes.py +42 -42
- routers/intervene_routes.py +53 -53
- routers/prediction_routes.py +27 -0
- routers/preprocess_routes.py +55 -55
- routers/timeseries_routes.py +30 -0
- routers/treatment_routes.py +53 -53
- routers/visualize_routes.py +42 -42
- scripts/generate_data.py +29 -29
- streamlit_app.py +618 -307
- utils/__pycache__/casual_algorithms.cpython-310.pyc +0 -0
- utils/__pycache__/causal_chatbot.cpython-310.pyc +0 -0
- utils/__pycache__/do_calculus.cpython-310.pyc +0 -0
- utils/__pycache__/graph_utils.cpython-310.pyc +0 -0
- utils/__pycache__/prediction_models.cpython-310.pyc +0 -0
- utils/__pycache__/preprocessor.cpython-310.pyc +0 -0
- utils/__pycache__/time_series_causal.cpython-310.pyc +0 -0
- utils/__pycache__/treatment_effects.cpython-310.pyc +0 -0
- utils/casual_algorithms.py +63 -63
- utils/causal_chatbot.py +271 -0
- utils/do_calculus.py +51 -51
- utils/graph_utils.py +107 -60
- utils/prediction_models.py +86 -0
- utils/preprocessor.py +88 -57
- utils/time_series_causal.py +102 -0
- utils/treatment_effects.py +62 -62
.env
ADDED
@@ -0,0 +1 @@

GROQ_API_KEY=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
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
All 35 lines were removed and re-added with identical content; the file is shown once below.

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
data/sample_dataset.csv
CHANGED
The diff for this file is too large to render. See raw diff.
main.py
CHANGED
@@ -1,38 +1,46 @@
Previous version (38 lines): the rendered diff truncates most of the removed lines; apart from the imports and the final app.run() call, the old content is not recoverable from this page.

New version (46 lines):

# main.py
from flask import Flask, jsonify, request
from flask_cors import CORS
import os
import sys
from dotenv import load_dotenv
load_dotenv()

# Add the 'routers' and 'utils' directories to the Python path
# This allows direct imports like 'from routers.preprocess_routes import preprocess_bp'
script_dir = os.path.dirname(__file__)
sys.path.insert(0, os.path.join(script_dir, 'routers'))
sys.path.insert(0, os.path.join(script_dir, 'utils'))

# Import Blueprints
from routers.preprocess_routes import preprocess_bp
from routers.discover_routes import discover_bp
from routers.intervene_routes import intervene_bp
from routers.treatment_routes import treatment_bp
from routers.visualize_routes import visualize_bp
from routers.prediction_routes import prediction_bp
from routers.timeseries_routes import timeseries_bp
from routers.chatbot_routes import chatbot_bp

app = Flask(__name__)
CORS(app)  # Enable CORS for frontend interaction

# Register Blueprints
app.register_blueprint(preprocess_bp, url_prefix='/preprocess')
app.register_blueprint(discover_bp, url_prefix='/discover')
app.register_blueprint(intervene_bp, url_prefix='/intervene')
app.register_blueprint(treatment_bp, url_prefix='/treatment')
app.register_blueprint(visualize_bp, url_prefix='/visualize')
app.register_blueprint(prediction_bp, url_prefix='/prediction')
app.register_blueprint(timeseries_bp, url_prefix='/timeseries')
app.register_blueprint(chatbot_bp, url_prefix='/chatbot')

@app.route('/')
def home():
    return "Welcome to CausalBox Backend API!"

if __name__ == '__main__':
    # Ensure the 'data' directory exists for storing datasets
    os.makedirs('data', exist_ok=True)
    # Run the Flask app
    app.run(debug=True, host='0.0.0.0', port=5000)
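For a quick check that the backend above is wired up, here is a minimal sketch, assuming the Flask app is running locally on the configured port 5000:

import requests

resp = requests.get("http://localhost:5000/")
print(resp.status_code)  # expected: 200
print(resp.text)         # expected: "Welcome to CausalBox Backend API!"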
requirements.txt
CHANGED
@@ -1,11 +1,14 @@
The first 11 entries are unchanged; the last three (statsmodels, google-generativeai, python-dotenv) are new.

Flask
flask-cors
pandas
numpy
scikit-learn
causal-learn # For PC algorithm
networkx
plotly
streamlit
requests # For Streamlit to communicate with Flask
watchfiles # For auto_refresh.py (if implemented for background tasks)
statsmodels # For statistical models and tests
google-generativeai
python-dotenv
routers/__pycache__/chatbot_routes.cpython-310.pyc
ADDED
Binary file (990 Bytes).

routers/__pycache__/discover_routes.cpython-310.pyc
CHANGED
Binary files a/routers/__pycache__/discover_routes.cpython-310.pyc and b/routers/__pycache__/discover_routes.cpython-310.pyc differ

routers/__pycache__/intervene_routes.cpython-310.pyc
CHANGED
Binary files a/routers/__pycache__/intervene_routes.cpython-310.pyc and b/routers/__pycache__/intervene_routes.cpython-310.pyc differ

routers/__pycache__/prediction_routes.cpython-310.pyc
ADDED
Binary file (1.13 kB).

routers/__pycache__/preprocess_routes.cpython-310.pyc
CHANGED
Binary files a/routers/__pycache__/preprocess_routes.cpython-310.pyc and b/routers/__pycache__/preprocess_routes.cpython-310.pyc differ

routers/__pycache__/timeseries_routes.cpython-310.pyc
ADDED
Binary file (1.29 kB).

routers/__pycache__/treatment_routes.cpython-310.pyc
CHANGED
Binary files a/routers/__pycache__/treatment_routes.cpython-310.pyc and b/routers/__pycache__/treatment_routes.cpython-310.pyc differ

routers/__pycache__/visualize_routes.cpython-310.pyc
CHANGED
Binary files a/routers/__pycache__/visualize_routes.cpython-310.pyc and b/routers/__pycache__/visualize_routes.cpython-310.pyc differ
routers/chatbot_routes.py
ADDED
@@ -0,0 +1,25 @@

# routers/chatbot_routes.py
from flask import Blueprint, request, jsonify
from utils.causal_chatbot import get_chatbot_response  # Import the core chatbot logic

chatbot_bp = Blueprint('chatbot_bp', __name__)

@chatbot_bp.route('/message', methods=['POST'])
def handle_chat_message():
    """
    API endpoint for the chatbot to receive user messages and provide responses.
    """
    data = request.json
    user_message = data.get('user_message')
    # Session context includes processed_data, causal_graph_adj, etc.
    session_context = data.get('session_context', {})

    if not user_message:
        return jsonify({"detail": "No user message provided."}), 400

    try:
        response_text = get_chatbot_response(user_message, session_context)
        return jsonify({"response": response_text}), 200
    except Exception as e:
        print(f"Error in chatbot route: {e}")
        return jsonify({"detail": f"An error occurred in the chatbot: {str(e)}"}), 500
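A minimal sketch of a client call to this endpoint, assuming the backend runs at localhost:5000 and the key in .env is valid for utils.causal_chatbot:

import requests

payload = {
    "user_message": "What does the discovered causal graph tell me?",
    "session_context": {},  # may also carry processed_data, causal_graph_adj, etc.
}
resp = requests.post("http://localhost:5000/chatbot/message", json=payload)
print(resp.json())  # {"response": "..."} on success, {"detail": "..."} with status 400/500 on failure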
routers/discover_routes.py
CHANGED
@@ -1,43 +1,43 @@
All 43 lines were removed and re-added with identical content; the file is shown once below.

# routers/discover_routes.py
from flask import Blueprint, request, jsonify
import pandas as pd
from utils.casual_algorithms import CausalDiscoveryAlgorithms
import logging

discover_bp = Blueprint('discover', __name__)
logger = logging.getLogger(__name__)

causal_discovery_algorithms = CausalDiscoveryAlgorithms()

@discover_bp.route('/', methods=['POST'])
def discover_causal_graph():
    """
    Discover causal graph from input data using selected algorithm.
    Expects 'data' key with list of dicts (preprocessed DataFrame records) and 'algorithm' string.
    Returns graph as adjacency matrix.
    """
    try:
        payload = request.json
        if not payload or 'data' not in payload:
            return jsonify({"detail": "Invalid request payload: 'data' key missing."}), 400

        df = pd.DataFrame(payload["data"])
        algorithm = payload.get("algorithm", "pc").lower()  # Default to PC

        logger.info(f"Received discovery request with algorithm: {algorithm}, data shape: {df.shape}")

        if algorithm == "pc":
            adj_matrix = causal_discovery_algorithms.pc_algorithm(df)
        elif algorithm == "ges":
            adj_matrix = causal_discovery_algorithms.ges_algorithm(df)  # Placeholder
        elif algorithm == "notears":
            adj_matrix = causal_discovery_algorithms.notears_algorithm(df)  # Placeholder
        else:
            return jsonify({"detail": f"Unsupported causal discovery algorithm: {algorithm}"}), 400

        logger.info(f"Causal graph discovered using {algorithm}.")
        return jsonify({"graph": adj_matrix.tolist()})

    except Exception as e:
        logger.exception(f"Error in causal discovery: {str(e)}")
        return jsonify({"detail": f"Causal discovery failed: {str(e)}"}), 500
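A minimal sketch of calling this endpoint end-to-end, assuming the backend runs at localhost:5000 and the repository's sample CSV is available at data/sample_dataset.csv:

import requests

# Preprocess first, then feed the returned records into discovery.
with open("data/sample_dataset.csv", "rb") as f:
    pre = requests.post("http://localhost:5000/preprocess/upload", files={"file": f})
records = pre.json()["data"]

resp = requests.post("http://localhost:5000/discover/",
                     json={"data": records, "algorithm": "pc"})
print(resp.json())  # {"graph": [[...], ...]} -- adjacency matrix over the preprocessed columns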
routers/intervene_routes.py
CHANGED
@@ -1,54 +1,54 @@
All 54 lines were removed and re-added with identical content; the file is shown once below.

# routers/intervene_routes.py
from flask import Blueprint, request, jsonify
import pandas as pd
from utils.do_calculus import DoCalculus  # Will be used for more advanced intervention
import networkx as nx  # Assuming graph is passed or re-discovered
import logging

intervene_bp = Blueprint('intervene', __name__)
logger = logging.getLogger(__name__)

@intervene_bp.route('/', methods=['POST'])
def perform_intervention():
    """
    Perform causal intervention on data.
    Expects 'data' (list of dicts), 'intervention_var' (column name),
    'intervention_value' (numeric), and optionally 'graph' (adjacency matrix).
    Returns intervened data as list of dicts.
    """
    try:
        payload = request.json
        if not payload or 'data' not in payload or 'intervention_var' not in payload or 'intervention_value' not in payload:
            return jsonify({"detail": "Missing required intervention parameters."}), 400

        df = pd.DataFrame(payload["data"])
        intervention_var = payload["intervention_var"]
        intervention_value = payload["intervention_value"]
        graph_adj_matrix = payload.get("graph")  # Optional: pass pre-discovered graph

        logger.info(f"Intervention request: var={intervention_var}, value={intervention_value}, data shape: {df.shape}")

        if intervention_var not in df.columns:
            return jsonify({"detail": f"Intervention variable '{intervention_var}' not found in data"}), 400

        # For a more advanced do-calculus, you'd need the graph structure.
        # Here, a simplified direct intervention is applied first.
        # If graph_adj_matrix is provided, you could convert it to networkx.
        # For full do-calculus, the DoCalculus class would need a proper graph.

        df_intervened = df.copy()
        df_intervened[intervention_var] = intervention_value

        # Placeholder for propagating effects using a graph if provided
        # if graph_adj_matrix:
        #     graph_nx = nx.from_numpy_array(np.array(graph_adj_matrix), create_using=nx.DiGraph)
        #     do_calculus_engine = DoCalculus(graph_nx)
        #     df_intervened = do_calculus_engine.intervene(df_intervened, intervention_var, intervention_value)
        #     logger.info("Propagated effects using do-calculus (simplified).")

        logger.info(f"Intervened data shape: {df_intervened.shape}")
        return jsonify({"intervened_data": df_intervened.to_dict(orient="records")})

    except Exception as e:
        logger.exception(f"Error in intervention: {str(e)}")
        return jsonify({"detail": f"Intervention failed: {str(e)}"}), 500
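A minimal sketch of an intervention request, assuming the backend runs at localhost:5000; the rows below are illustrative stand-ins for the records returned by /preprocess/upload:

import requests

records = [
    {"StudyHours": 9.5, "TuitionHours": 4.8, "FinalExamScore": 75.2},
    {"StudyHours": 11.0, "TuitionHours": 5.3, "FinalExamScore": 82.1},
]
resp = requests.post("http://localhost:5000/intervene/",
                     json={"data": records,
                           "intervention_var": "StudyHours",
                           "intervention_value": 12.0})
print(resp.json()["intervened_data"])  # every row now has StudyHours == 12.0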
routers/prediction_routes.py
ADDED
@@ -0,0 +1,27 @@

# routers/prediction_routes.py
from flask import Blueprint, request, jsonify
import pandas as pd
from utils.prediction_models import train_predict_random_forest

prediction_bp = Blueprint('prediction_bp', __name__)

@prediction_bp.route('/train_predict', methods=['POST'])
def train_predict():
    """
    API endpoint to train a Random Forest model and perform prediction/evaluation.
    """
    data = request.json.get('data')
    target_col = request.json.get('target_col')
    feature_cols = request.json.get('feature_cols')
    prediction_type = request.json.get('prediction_type')

    if not all([data, target_col, feature_cols, prediction_type]):
        return jsonify({"detail": "Missing required parameters for prediction."}), 400

    try:
        results = train_predict_random_forest(data, target_col, feature_cols, prediction_type)
        return jsonify({"results": results}), 200
    except ValueError as e:
        return jsonify({"detail": str(e)}), 400
    except Exception as e:
        return jsonify({"detail": f"An error occurred during prediction: {str(e)}"}), 500
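A minimal sketch of a prediction request, assuming the backend runs at localhost:5000 and the sample CSV is readable locally; "regression" as the prediction_type is an assumption, since the accepted values are defined by train_predict_random_forest in utils/prediction_models.py:

import pandas as pd
import requests

records = pd.read_csv("data/sample_dataset.csv").to_dict(orient="records")
payload = {
    "data": records,
    "target_col": "FinalExamScore",
    "feature_cols": ["StudyHours", "TuitionHours"],
    "prediction_type": "regression",  # assumed value; see train_predict_random_forest
}
resp = requests.post("http://localhost:5000/prediction/train_predict", json=payload)
print(resp.json())  # {"results": {...}} on success, {"detail": "..."} on error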
routers/preprocess_routes.py
CHANGED
@@ -1,56 +1,56 @@
All 56 lines were removed and re-added with identical content; the file is shown once below.

# routers/preprocess_routes.py
from flask import Blueprint, request, jsonify
import pandas as pd
from utils.preprocessor import DataPreprocessor
import logging

preprocess_bp = Blueprint('preprocess', __name__)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

preprocessor = DataPreprocessor()

@preprocess_bp.route('/upload', methods=['POST'])
def upload_file():
    """
    Upload and preprocess a CSV file.
    Returns preprocessed DataFrame columns and data as JSON.
    Optional limit_rows to reduce response size for testing.
    """
    if 'file' not in request.files:
        return jsonify({"detail": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"detail": "No selected file"}), 400
    if not file.filename.lower().endswith('.csv'):
        return jsonify({"detail": "Only CSV files are supported"}), 400

    limit_rows = request.args.get('limit_rows', type=int)

    try:
        logger.info(f"Received file: {file.filename}")
        df = pd.read_csv(file)
        logger.info(f"CSV read successfully, shape: {df.shape}")

        processed_df = preprocessor.preprocess(df)
        if limit_rows:
            processed_df = processed_df.head(limit_rows)
            logger.info(f"Limited to {limit_rows} rows.")

        response = {
            "columns": list(processed_df.columns),
            "data": processed_df.to_dict(orient="records")
        }
        logger.info(f"Preprocessed {len(response['data'])} records.")
        return jsonify(response)
    except pd.errors.EmptyDataError:
        logger.error("Empty CSV file uploaded.")
        return jsonify({"detail": "Empty CSV file"}), 400
    except pd.errors.ParserError:
        logger.error("Invalid CSV format.")
        return jsonify({"detail": "Invalid CSV format"}), 400
    except Exception as e:
        logger.exception(f"Unexpected error during file processing: {str(e)}")
        return jsonify({"detail": f"Failed to process file: {str(e)}"}), 500
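A minimal sketch of uploading a CSV to this endpoint, assuming the backend runs at localhost:5000 and the repository's sample dataset exists locally:

import requests

with open("data/sample_dataset.csv", "rb") as f:
    resp = requests.post("http://localhost:5000/preprocess/upload?limit_rows=100",
                         files={"file": ("sample_dataset.csv", f, "text/csv")})
result = resp.json()
print(result["columns"])    # preprocessed column names
print(len(result["data"]))  # at most 100 records because of limit_rows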
routers/timeseries_routes.py
ADDED
@@ -0,0 +1,30 @@

# routers/timeseries_routes.py
from flask import Blueprint, request, jsonify
import pandas as pd
from utils.time_series_causal import perform_granger_causality

timeseries_bp = Blueprint('timeseries_bp', __name__)

@timeseries_bp.route('/discover_causality', methods=['POST'])
def discover_timeseries_causality():
    """
    API endpoint to perform time-series causal discovery (Granger Causality).
    """
    data = request.json.get('data')
    timestamp_col = request.json.get('timestamp_col')
    variables_to_analyze = request.json.get('variables_to_analyze')
    max_lags = request.json.get('max_lags', 1)  # Default to 1 lag

    if not all([data, timestamp_col, variables_to_analyze]):
        return jsonify({"detail": "Missing required parameters for time-series causal discovery."}), 400

    if not isinstance(max_lags, int) or max_lags <= 0:
        return jsonify({"detail": "max_lags must be a positive integer."}), 400

    try:
        results = perform_granger_causality(data, timestamp_col, variables_to_analyze, max_lags)
        return jsonify({"results": results}), 200
    except ValueError as e:
        return jsonify({"detail": str(e)}), 400
    except Exception as e:
        return jsonify({"detail": f"An error occurred during time-series causal discovery: {str(e)}"}), 500
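A minimal sketch of a Granger-causality request, assuming the backend runs at localhost:5000; the column names ("date", "sales", "ad_spend") are purely illustrative and not part of the repository's sample data:

import requests

records = [
    {"date": "2024-01-01", "sales": 100.0, "ad_spend": 10.0},
    {"date": "2024-01-02", "sales": 110.0, "ad_spend": 12.0},
    # ... more rows; Granger tests need enough observations for the chosen number of lags
]
payload = {
    "data": records,
    "timestamp_col": "date",
    "variables_to_analyze": ["sales", "ad_spend"],
    "max_lags": 2,
}
resp = requests.post("http://localhost:5000/timeseries/discover_causality", json=payload)
print(resp.json())  # {"results": {...}} on success, {"detail": "..."} on error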
routers/treatment_routes.py
CHANGED
@@ -1,54 +1,54 @@
All 54 lines were removed and re-added with identical content; the file is shown once below.

# routers/treatment_routes.py
from flask import Blueprint, request, jsonify
import pandas as pd
from utils.treatment_effects import TreatmentEffectAlgorithms
import logging

treatment_bp = Blueprint('treatment', __name__)
logger = logging.getLogger(__name__)

treatment_effect_algorithms = TreatmentEffectAlgorithms()

@treatment_bp.route('/estimate_ate', methods=['POST'])
def estimate_ate():
    """
    Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).
    Expects 'data' (list of dicts), 'treatment_col', 'outcome_col', 'covariates' (list of column names),
    and 'method' (string for estimation method).
    Returns ATE/CATE as float or dictionary.
    """
    try:
        payload = request.json
        if not payload or 'data' not in payload or 'treatment_col' not in payload or 'outcome_col' not in payload or 'covariates' not in payload:
            return jsonify({"detail": "Missing required ATE estimation parameters."}), 400

        df = pd.DataFrame(payload["data"])
        treatment_col = payload["treatment_col"]
        outcome_col = payload["outcome_col"]
        covariates = payload["covariates"]
        method = payload.get("method", "linear_regression").lower()  # Default to linear regression

        logger.info(f"ATE/CATE request: treatment={treatment_col}, outcome={outcome_col}, method={method}, data shape: {df.shape}")

        if not all(col in df.columns for col in [treatment_col, outcome_col] + covariates):
            return jsonify({"detail": "Invalid column names provided for ATE estimation."}), 400

        if method == "linear_regression":
            result = treatment_effect_algorithms.linear_regression_ate(df, treatment_col, outcome_col, covariates)
        elif method == "propensity_score_matching":
            result = treatment_effect_algorithms.propensity_score_matching(df, treatment_col, outcome_col, covariates)  # Placeholder
        elif method == "inverse_propensity_weighting":
            result = treatment_effect_algorithms.inverse_propensity_weighting(df, treatment_col, outcome_col, covariates)  # Placeholder
        elif method == "t_learner":
            result = treatment_effect_algorithms.t_learner(df, treatment_col, outcome_col, covariates)  # Placeholder
        elif method == "s_learner":
            result = treatment_effect_algorithms.s_learner(df, treatment_col, outcome_col, covariates)  # Placeholder
        else:
            return jsonify({"detail": f"Unsupported treatment effect estimation method: {method}"}), 400

        logger.info(f"Estimated ATE/CATE using {method}: {result}")
        return jsonify({"result": result})

    except Exception as e:
        logger.exception(f"Error in ATE/CATE estimation: {str(e)}")
        return jsonify({"detail": f"ATE/CATE estimation failed: {str(e)}"}), 500
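A minimal sketch of an ATE request using the default linear-regression method, assuming the backend runs at localhost:5000 and the sample CSV is preprocessed first so that all columns are numeric:

import requests

with open("data/sample_dataset.csv", "rb") as f:
    pre = requests.post("http://localhost:5000/preprocess/upload", files={"file": f})
records = pre.json()["data"]

payload = {
    "data": records,
    "treatment_col": "StudyHours",
    "outcome_col": "FinalExamScore",
    "covariates": ["TuitionHours"],
    "method": "linear_regression",
}
resp = requests.post("http://localhost:5000/treatment/estimate_ate", json=payload)
print(resp.json())  # {"result": <estimated ATE>}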
routers/visualize_routes.py
CHANGED
@@ -1,43 +1,43 @@
All 43 lines were removed and re-added with identical content; the file is shown once below.

# routers/visualize_routes.py
from flask import Blueprint, request, jsonify
import pandas as pd
from utils.graph_utils import visualize_graph
import networkx as nx
import numpy as np
import logging

visualize_bp = Blueprint('visualize', __name__)
logger = logging.getLogger(__name__)

@visualize_bp.route('/graph', methods=['POST'])
def get_graph_visualization():
    """
    Generate a causal graph visualization from an adjacency matrix.
    Expects 'graph' (adjacency matrix as list of lists) and 'nodes' (list of node names).
    Returns Plotly JSON for the graph.
    """
    try:
        payload = request.json
        if not payload or 'graph' not in payload or 'nodes' not in payload:
            return jsonify({"detail": "Missing 'graph' or 'nodes' in request payload."}), 400

        adj_matrix = np.array(payload["graph"])
        nodes = payload["nodes"]

        logger.info(f"Received graph visualization request for {len(nodes)} nodes.")

        # Reconstruct networkx graph from adjacency matrix and node names
        graph_nx = nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph)

        # Map integer node labels back to original column names if necessary
        # Assuming nodes are ordered as they appear in the original dataframe or provided in 'nodes'
        mapping = {i: node_name for i, node_name in enumerate(nodes)}
        graph_nx = nx.relabel_nodes(graph_nx, mapping)

        graph_json = visualize_graph(graph_nx)
        logger.info("Generated graph visualization JSON.")
        return jsonify({"graph": graph_json})

    except Exception as e:
        logger.exception(f"Error generating graph visualization: {str(e)}")
        return jsonify({"detail": f"Failed to generate visualization: {str(e)}"}), 500
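A minimal sketch of requesting and rendering a graph, assuming the backend runs at localhost:5000; the 3x3 adjacency matrix below is hand-written for illustration and would normally come from the /discover/ response:

import json
import requests
import plotly.graph_objects as go

adj = [[0, 1, 1],
       [0, 0, 1],
       [0, 0, 0]]
nodes = ["StudyHours", "TuitionHours", "FinalExamScore"]

resp = requests.post("http://localhost:5000/visualize/graph",
                     json={"graph": adj, "nodes": nodes})
fig = go.Figure(json.loads(resp.json()["graph"]))  # same decoding streamlit_app.py uses
fig.show()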
scripts/generate_data.py
CHANGED
@@ -1,29 +1,29 @@
All 29 lines were removed and re-added with identical content; the file is shown once below.

# scripts/generate_data.py
import numpy as np
import pandas as pd
import os

def generate_dataset(n_samples=1000):
    np.random.seed(42)
    study_hours = np.random.normal(10, 2, n_samples)
    tuition_hours = np.random.normal(5, 1, n_samples)
    parental_education = np.random.choice(['High', 'Medium', 'Low'], n_samples)
    school_type = np.random.choice(['Public', 'Private'], n_samples)
    exam_score = 50 + 2 * study_hours + 1.5 * tuition_hours + np.random.normal(0, 5, n_samples)

    df = pd.DataFrame({
        'StudyHours': study_hours,
        'TuitionHours': tuition_hours,
        'ParentalEducation': parental_education,
        'SchoolType': school_type,
        'FinalExamScore': exam_score
    })

    # Ensure data directory exists
    os.makedirs('../data', exist_ok=True)
    df.to_csv('../data/sample_dataset.csv', index=False)
    return df

if __name__ == "__main__":
    generate_dataset()
    print("Dataset generated and saved to ../data/sample_dataset.csv")
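A minimal sketch of regenerating the sample data with a different size, assuming it is run from inside the scripts/ directory (the output path is relative: ../data/sample_dataset.csv):

from generate_data import generate_dataset

df = generate_dataset(n_samples=500)
print(df.shape)   # (500, 5)
print(df.head())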
streamlit_app.py
CHANGED
@@ -1,308 +1,619 @@
|
|
1 |
-
# streamlit_app.py
|
2 |
-
import streamlit as st
|
3 |
-
import pandas as pd
|
4 |
-
import requests
|
5 |
-
import json
|
6 |
-
import plotly.express as px
|
7 |
-
import plotly.graph_objects as go
|
8 |
-
import numpy as np # For random array in placeholders
|
9 |
-
import os
|
10 |
-
|
11 |
-
# Configuration
|
12 |
-
FLASK_API_URL = "http://localhost:5000" # Ensure this matches your Flask app's host and port
|
13 |
-
|
14 |
-
st.set_page_config(layout="wide", page_title="CausalBox Toolkit")
|
15 |
-
|
16 |
-
st.title("🔬 CausalBox: A Causal Inference Toolkit")
|
17 |
-
st.markdown("Uncover causal relationships, simulate interventions, and estimate treatment effects.")
|
18 |
-
|
19 |
-
# --- Session State Initialization ---
|
20 |
-
if 'processed_data' not in st.session_state:
|
21 |
-
st.session_state.processed_data = None
|
22 |
-
if 'processed_columns' not in st.session_state:
|
23 |
-
st.session_state.processed_columns = None
|
24 |
-
if 'causal_graph_adj' not in st.session_state:
|
25 |
-
st.session_state.causal_graph_adj = None
|
26 |
-
if 'causal_graph_nodes' not in st.session_state:
|
27 |
-
st.session_state.causal_graph_nodes = None
|
28 |
-
|
29 |
-
# --- Data Preprocessing Module ---
|
30 |
-
st.header("1. Data Preprocessor 🧹")
|
31 |
-
st.write("Upload your CSV dataset or use a generated sample dataset.")
|
32 |
-
|
33 |
-
# Option to use generated sample dataset
|
34 |
-
if st.button("Use Sample Dataset (sample_dataset.csv)"):
|
35 |
-
# In a real scenario, Streamlit would serve the file or you'd load it directly if local.
|
36 |
-
# For this setup, we assume the Flask backend can access it or you manually upload it once.
|
37 |
-
# For demonstration, we'll simulate loading a generic DataFrame.
|
38 |
-
# In a full deployment, you'd have a mechanism to either:
|
39 |
-
# a) Have Flask serve the sample file, or
|
40 |
-
# b) Directly load it in Streamlit if the app and data are co-located.
|
41 |
-
try:
|
42 |
-
# Assuming the sample dataset is accessible or you are testing locally with `scripts/generate_data.py`
|
43 |
-
# and then manually uploading this generated file.
|
44 |
-
# For simplicity, we'll create a dummy df here if not actually uploaded.
|
45 |
-
sample_df_path = "data/sample_dataset.csv" # Path relative to main.py or Streamlit app execution
|
46 |
-
if os.path.exists(sample_df_path):
|
47 |
-
sample_df = pd.read_csv(sample_df_path)
|
48 |
-
st.success(f"Loaded sample dataset from {sample_df_path}. Please upload this file if running from different directory.")
|
49 |
-
else:
|
50 |
-
st.warning("Sample dataset not found at data/sample_dataset.csv.")
|
51 |
-
# Dummy DataFrame for demonstration if sample file isn't found
|
52 |
-
sample_df = pd.DataFrame(np.random.rand(10, 5), columns=[f'col_{i}' for i in range(5)])
|
53 |
-
|
54 |
-
# Convert to JSON for Flask API call
|
55 |
-
files = {'file': ('sample_dataset.csv', sample_df.to_csv(index=False), 'text/csv')}
|
56 |
-
response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
|
57 |
-
|
58 |
-
if response.status_code == 200:
|
59 |
-
result = response.json()
|
60 |
-
st.session_state.processed_data = result['data']
|
61 |
-
st.session_state.processed_columns = result['columns']
|
62 |
-
st.success("Sample dataset preprocessed successfully!")
|
63 |
-
st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
|
64 |
-
else:
|
65 |
-
st.error(f"Error preprocessing sample dataset: {response.json().get('detail', 'Unknown error')}")
|
66 |
-
except Exception as e:
|
67 |
-
st.error(f"Could not load or process sample dataset: {e}")
|
68 |
-
|
69 |
-
|
70 |
-
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
|
71 |
-
if uploaded_file is not None:
|
72 |
-
st.info("Uploading and preprocessing data...")
|
73 |
-
files = {'file': (uploaded_file.name, uploaded_file.getvalue(), 'text/csv')}
|
74 |
-
try:
|
75 |
-
response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
|
76 |
-
if response.status_code == 200:
|
77 |
-
result = response.json()
|
78 |
-
st.session_state.processed_data = result['data']
|
79 |
-
st.session_state.processed_columns = result['columns']
|
80 |
-
st.success("File preprocessed successfully!")
|
81 |
-
st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
|
82 |
-
else:
|
83 |
-
st.error(f"Error during preprocessing: {response.json().get('detail', 'Unknown error')}")
|
84 |
-
except requests.exceptions.ConnectionError:
|
85 |
-
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
86 |
-
except Exception as e:
|
87 |
-
st.error(f"An unexpected error occurred: {e}")
|
88 |
-
|
89 |
-
# --- Causal Discovery Module ---
|
90 |
-
st.header("2. Causal Discovery 🕵️♂️")
|
91 |
-
if st.session_state.processed_data:
|
92 |
-
st.write("Learn the causal structure from your preprocessed data.")
|
93 |
-
|
94 |
-
discovery_algo = st.selectbox(
|
95 |
-
"Select Causal Discovery Algorithm:",
|
96 |
-
("PC Algorithm", "GES (Greedy Equivalence Search) - Placeholder", "NOTEARS - Placeholder")
|
97 |
-
)
|
98 |
-
|
99 |
-
if st.button("Discover Causal Graph"):
|
100 |
-
st.info(f"Discovering graph using {discovery_algo}...")
|
101 |
-
algo_map = {
|
102 |
-
"PC Algorithm": "pc",
|
103 |
-
"GES (Greedy Equivalence Search) - Placeholder": "ges",
|
104 |
-
"NOTEARS - Placeholder": "notears"
|
105 |
-
}
|
106 |
-
selected_algo_code = algo_map[discovery_algo]
|
107 |
-
|
108 |
-
try:
|
109 |
-
response = requests.post(
|
110 |
-
f"{FLASK_API_URL}/discover/",
|
111 |
-
json={"data": st.session_state.processed_data, "algorithm": selected_algo_code}
|
112 |
-
)
|
113 |
-
if response.status_code == 200:
|
114 |
-
result = response.json()
|
115 |
-
st.session_state.causal_graph_adj = result['graph']
|
116 |
-
st.session_state.causal_graph_nodes = st.session_state.processed_columns
|
117 |
-
st.success("Causal graph discovered!")
|
118 |
-
st.subheader("Causal Graph Visualization")
|
119 |
-
# Visualization will be handled by the Causal Graph Visualizer section
|
120 |
-
else:
|
121 |
-
st.error(f"Error during causal discovery: {response.json().get('detail', 'Unknown error')}")
|
122 |
-
except requests.exceptions.ConnectionError:
|
123 |
-
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
124 |
-
except Exception as e:
|
125 |
-
st.error(f"An unexpected error occurred: {e}")
|
126 |
-
else:
|
127 |
-
st.info("Please preprocess data first to enable causal discovery.")
|
128 |
-
|
129 |
-
# --- Causal Graph Visualizer Module ---
|
130 |
-
st.header("3. Causal Graph Visualizer 📊")
|
131 |
-
if st.session_state.causal_graph_adj and st.session_state.causal_graph_nodes:
|
132 |
-
st.write("Interactive visualization of the discovered causal graph.")
|
133 |
-
try:
|
134 |
-
response = requests.post(
|
135 |
-
f"{FLASK_API_URL}/visualize/graph",
|
136 |
-
json={"graph": st.session_state.causal_graph_adj, "nodes": st.session_state.causal_graph_nodes}
|
137 |
-
)
|
138 |
-
if response.status_code == 200:
|
139 |
-
graph_json = response.json()['graph']
|
140 |
-
fig = go.Figure(json.loads(graph_json))
|
141 |
-
st.plotly_chart(fig, use_container_width=True)
|
142 |
-
st.markdown("""
|
143 |
-
**Graph Explanation:**
|
144 |
-
* **Nodes:** Represent variables in your dataset.
|
145 |
-
* **Arrows (Edges):** Indicate a direct causal influence from one variable (the tail) to another (the head).
|
146 |
-
* **No Arrow:** Suggests no direct causal relationship was found, or the relationship is mediated by other variables.
|
147 |
-
|
148 |
-
This graph helps answer "Why did it happen?" by showing the structural relationships.
|
149 |
-
""")
|
150 |
-
else:
|
151 |
-
st.error(f"Error visualizing graph: {response.json().get('detail', 'Unknown error')}")
|
152 |
-
except requests.exceptions.ConnectionError:
|
153 |
-
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
154 |
-
except Exception as e:
|
155 |
-
st.error(f"An unexpected error occurred during visualization: {e}")
|
156 |
-
else:
|
157 |
-
st.info("Please discover a causal graph first to visualize it.")
|
158 |
-
|
159 |
-
|
160 |
-
# --- Do-Calculus Engine Module ---
|
161 |
-
st.header("4. Do-Calculus Engine 🧪")
|
162 |
-
if st.session_state.processed_data and st.session_state.causal_graph_adj:
|
163 |
-
st.write("Simulate interventions and observe their effects based on the causal graph.")
|
164 |
-
|
165 |
-
intervention_var = st.selectbox(
|
166 |
-
"Select variable to intervene on:",
|
167 |
-
st.session_state.processed_columns,
|
168 |
-
key="inter_var_select"
|
169 |
-
)
|
170 |
-
# Attempt to infer type for intervention_value input
|
171 |
-
# Simplified approach: assuming numerical for now due to preprocessor output
|
172 |
-
if intervention_var and isinstance(st.session_state.processed_data[0][intervention_var], (int, float)):
|
173 |
-
intervention_value = st.number_input(f"Set '{intervention_var}' to value:", key="inter_val_input")
|
174 |
-
else: # Treat as string/categorical for input, then try to preprocess for API
|
175 |
-
intervention_value = st.text_input(f"Set '{intervention_var}' to value:", key="inter_val_input_text")
|
176 |
-
st.warning("Categorical intervention values might require specific encoding logic on the backend.")
|
177 |
-
|
178 |
-
if st.button("Perform Intervention"):
|
179 |
-
st.info(f"Performing intervention: do('{intervention_var}' = {intervention_value})...")
|
180 |
-
try:
|
181 |
-
response = requests.post(
|
182 |
-
f"{FLASK_API_URL}/intervene/",
|
183 |
-
json={
|
184 |
-
"data": st.session_state.processed_data,
|
185 |
-
"intervention_var": intervention_var,
|
186 |
-
"intervention_value": intervention_value,
|
187 |
-
"graph": st.session_state.causal_graph_adj # Pass graph for advanced do-calculus
|
188 |
-
}
|
189 |
-
)
|
190 |
-
if response.status_code == 200:
|
191 |
-
intervened_data = pd.DataFrame(response.json()['intervened_data'])
|
192 |
-
st.success("Intervention simulated successfully!")
|
193 |
-
st.subheader("Intervened Data (First 10 rows)")
|
194 |
-
st.dataframe(intervened_data.head(10))
|
195 |
-
|
196 |
-
# Simple comparison visualization (e.g., histogram of outcome variable)
|
197 |
-
if st.session_state.processed_columns and 'FinalExamScore' in st.session_state.processed_columns:
|
198 |
-
original_df = pd.DataFrame(st.session_state.processed_data)
|
199 |
-
fig_dist = go.Figure()
|
200 |
-
fig_dist.add_trace(go.Histogram(x=original_df['FinalExamScore'], name='Original', opacity=0.7))
|
201 |
-
fig_dist.add_trace(go.Histogram(x=intervened_data['FinalExamScore'], name='Intervened', opacity=0.0))
|
202 |
-
|
203 |
-
st.plotly_chart(fig_dist, use_container_width=True)
|
204 |
-
st.markdown("""
|
205 |
-
**Intervention Explanation:**
|
206 |
-
* By simulating `do(X=x)`, we are forcing the value of X, effectively breaking its causal links from its parents.
|
207 |
-
* The graph above shows the distribution of a key outcome variable (e.g., `FinalExamScore`) before and after the intervention.
|
208 |
-
* This helps answer "What if we do this instead?" by showing the predicted outcome.
|
209 |
-
""")
|
210 |
-
else:
|
211 |
-
st.info("Consider adding a relevant outcome variable to your dataset for better intervention analysis.")
|
212 |
-
else:
|
213 |
-
st.error(f"Error during intervention: {response.json().get('detail', 'Unknown error')}")
|
214 |
-
except requests.exceptions.ConnectionError:
|
215 |
-
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
216 |
-
except Exception as e:
|
217 |
-
st.error(f"An unexpected error occurred during intervention: {e}")
|
218 |
-
else:
|
219 |
-
st.info("Please preprocess data and discover a causal graph first to perform interventions.")
|
220 |
-
|
221 |
-
# --- Treatment Effect Estimator Module ---
|
222 |
-
st.header("5. Treatment Effect Estimator 🎯")
|
223 |
-
if st.session_state.processed_data:
|
224 |
-
st.write("Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).")
|
225 |
-
|
226 |
-
col1, col2 = st.columns(2)
|
227 |
-
with col1:
|
228 |
-
treatment_col = st.selectbox(
|
229 |
-
"Select Treatment Variable:",
|
230 |
-
st.session_state.processed_columns,
|
231 |
-
key="treat_col_select"
|
232 |
-
)
|
233 |
-
with col2:
|
234 |
-
outcome_col = st.selectbox(
|
235 |
-
"Select Outcome Variable:",
|
236 |
-
st.session_state.processed_columns,
|
237 |
-
key="outcome_col_select"
|
238 |
-
)
|
239 |
-
|
240 |
-
all_cols_except_treat_outcome = [col for col in st.session_state.processed_columns if col not in [treatment_col, outcome_col]]
|
241 |
-
covariates = st.multiselect(
|
242 |
-
"Select Covariates (confounders):",
|
243 |
-
all_cols_except_treat_outcome,
|
244 |
-
default=all_cols_except_treat_outcome, # Default to all other columns
|
245 |
-
key="covariates_select"
|
246 |
-
)
|
247 |
-
|
248 |
-
estimation_method = st.selectbox(
|
249 |
-
"Select Estimation Method:",
|
250 |
-
(
|
251 |
-
"Linear Regression ATE",
|
252 |
-
"Propensity Score Matching - Placeholder",
|
253 |
-
"Inverse Propensity Weighting - Placeholder",
|
254 |
-
"T-learner - Placeholder",
|
255 |
-
"S-learner - Placeholder"
|
256 |
-
)
|
257 |
-
)
|
258 |
-
|
259 |
-
if st.button("Estimate Treatment Effect"):
|
260 |
-
st.info(f"Estimating treatment effect using {estimation_method}...")
|
261 |
-
method_map = {
|
262 |
-
"Linear Regression ATE": "linear_regression",
|
263 |
-
"Propensity Score Matching - Placeholder": "propensity_score_matching",
|
264 |
-
"Inverse Propensity Weighting - Placeholder": "inverse_propensity_weighting",
|
265 |
-
"T-learner - Placeholder": "t_learner",
|
266 |
-
"S-learner - Placeholder": "s_learner"
|
267 |
-
}
|
268 |
-
selected_method_code = method_map[estimation_method]
|
269 |
-
|
270 |
-
try:
|
271 |
-
response = requests.post(
|
272 |
-
f"{FLASK_API_URL}/treatment/estimate_ate",
|
273 |
-
json={
|
274 |
-
"data": st.session_state.processed_data,
|
275 |
-
"treatment_col": treatment_col,
|
276 |
-
"outcome_col": outcome_col,
|
277 |
-
"covariates": covariates,
|
278 |
-
"method": selected_method_code
|
279 |
-
}
|
280 |
-
)
|
281 |
-
if response.status_code == 200:
|
282 |
-
ate_result = response.json()['result']
|
283 |
-
st.success(f"Treatment effect estimated using {estimation_method}:")
|
284 |
-
st.write(f"**Estimated ATE: {ate_result:.4f}**")
|
285 |
-
st.markdown("""
|
286 |
-
**Treatment Effect Explanation:**
|
287 |
-
* **Average Treatment Effect (ATE):** Measures the average causal effect of a treatment (e.g., `StudyHours`) on an outcome (e.g., `FinalExamScore`) across the entire population.
|
288 |
-
* It answers "How much does doing X cause a change in Y?".
|
289 |
-
* This estimation attempts to control for confounders (variables that influence both treatment and outcome) to isolate the true causal effect.
|
290 |
-
""")
|
291 |
-
else:
|
292 |
-
st.error(f"Error during ATE estimation: {response.json().get('detail', 'Unknown error')}")
|
293 |
-
except requests.exceptions.ConnectionError:
|
294 |
-
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
295 |
-
except Exception as e:
|
296 |
-
st.error(f"An unexpected error occurred during ATE estimation: {e}")
|
297 |
-
else:
|
298 |
-
st.info("Please preprocess data first to estimate treatment effects.")
|
299 |
-
|
300 |
-
# ---
|
301 |
-
st.header("
|
302 |
-
st.
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
st.info("Developed by CausalBox Team. For support, please contact us.")
|
|
|
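Editor's note on the "Linear Regression ATE" option in the Treatment Effect Estimator above: a common way to estimate the ATE with covariate adjustment is to regress the outcome on the treatment plus the selected covariates and read off the treatment coefficient. The sketch below is only assumed to mirror what the backend's /treatment/estimate_ate route computes; the helper name and column handling are illustrative.

# Minimal sketch: ATE via covariate-adjusted linear regression.
# Assumes a linear outcome model and no unobserved confounding, in which case
# the coefficient on the treatment column can be read as the ATE.
import pandas as pd
from sklearn.linear_model import LinearRegression

def linear_regression_ate(df: pd.DataFrame, treatment_col: str, outcome_col: str, covariates: list) -> float:
    X = df[[treatment_col] + list(covariates)]
    y = df[outcome_col]
    model = LinearRegression().fit(X, y)
    return float(model.coef_[0])  # coefficient on the treatment column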
1 |
+
# streamlit_app.py
|
2 |
+
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
+
import requests
|
5 |
+
import json
|
6 |
+
import plotly.express as px
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
import numpy as np # For random array in placeholders
|
9 |
+
import os
|
10 |
+
|
11 |
+
# Configuration
|
12 |
+
FLASK_API_URL = "http://localhost:5000" # Ensure this matches your Flask app's host and port
|
13 |
+
|
14 |
+
st.set_page_config(layout="wide", page_title="CausalBox Toolkit")
|
15 |
+
|
16 |
+
st.title("🔬 CausalBox: A Causal Inference Toolkit")
|
17 |
+
st.markdown("Uncover causal relationships, simulate interventions, and estimate treatment effects.")
|
18 |
+
|
19 |
+
# --- Session State Initialization ---
|
20 |
+
if 'processed_data' not in st.session_state:
|
21 |
+
st.session_state.processed_data = None
|
22 |
+
if 'processed_columns' not in st.session_state:
|
23 |
+
st.session_state.processed_columns = None
|
24 |
+
if 'causal_graph_adj' not in st.session_state:
|
25 |
+
st.session_state.causal_graph_adj = None
|
26 |
+
if 'causal_graph_nodes' not in st.session_state:
|
27 |
+
st.session_state.causal_graph_nodes = None
|
28 |
+
|
29 |
+
# --- Data Preprocessing Module ---
|
30 |
+
st.header("1. Data Preprocessor 🧹")
|
31 |
+
st.write("Upload your CSV dataset or use a generated sample dataset.")
|
32 |
+
|
33 |
+
# Option to use generated sample dataset
|
34 |
+
if st.button("Use Sample Dataset (sample_dataset.csv)"):
|
35 |
+
# In a real scenario, Streamlit would serve the file or you'd load it directly if local.
|
36 |
+
# For this setup, we assume the Flask backend can access it or you manually upload it once.
|
37 |
+
# For demonstration, we'll simulate loading a generic DataFrame.
|
38 |
+
# In a full deployment, you'd have a mechanism to either:
|
39 |
+
# a) Have Flask serve the sample file, or
|
40 |
+
# b) Directly load it in Streamlit if the app and data are co-located.
|
41 |
+
try:
|
42 |
+
# Assuming the sample dataset is accessible or you are testing locally with `scripts/generate_data.py`
|
43 |
+
# and then manually uploading this generated file.
|
44 |
+
# For simplicity, we'll create a dummy df here if not actually uploaded.
|
45 |
+
sample_df_path = "data/sample_dataset.csv" # Path relative to main.py or Streamlit app execution
|
46 |
+
if os.path.exists(sample_df_path):
|
47 |
+
sample_df = pd.read_csv(sample_df_path)
|
48 |
+
st.success(f"Loaded sample dataset from {sample_df_path}. Please upload this file if running from different directory.")
|
49 |
+
else:
|
50 |
+
st.warning("Sample dataset not found at data/sample_dataset.csv.")
|
51 |
+
# Dummy DataFrame for demonstration if sample file isn't found
|
52 |
+
sample_df = pd.DataFrame(np.random.rand(10, 5), columns=[f'col_{i}' for i in range(5)])
|
53 |
+
|
54 |
+
# Convert to JSON for Flask API call
|
55 |
+
files = {'file': ('sample_dataset.csv', sample_df.to_csv(index=False), 'text/csv')}
|
56 |
+
response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
|
57 |
+
|
58 |
+
if response.status_code == 200:
|
59 |
+
result = response.json()
|
60 |
+
st.session_state.processed_data = result['data']
|
61 |
+
st.session_state.processed_columns = result['columns']
|
62 |
+
st.success("Sample dataset preprocessed successfully!")
|
63 |
+
st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
|
64 |
+
else:
|
65 |
+
st.error(f"Error preprocessing sample dataset: {response.json().get('detail', 'Unknown error')}")
|
66 |
+
except Exception as e:
|
67 |
+
st.error(f"Could not load or process sample dataset: {e}")
|
68 |
+
|
69 |
+
|
70 |
+
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
|
71 |
+
if uploaded_file is not None:
|
72 |
+
st.info("Uploading and preprocessing data...")
|
73 |
+
files = {'file': (uploaded_file.name, uploaded_file.getvalue(), 'text/csv')}
|
74 |
+
try:
|
75 |
+
response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
|
76 |
+
if response.status_code == 200:
|
77 |
+
result = response.json()
|
78 |
+
st.session_state.processed_data = result['data']
|
79 |
+
st.session_state.processed_columns = result['columns']
|
80 |
+
st.success("File preprocessed successfully!")
|
81 |
+
st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
|
82 |
+
else:
|
83 |
+
st.error(f"Error during preprocessing: {response.json().get('detail', 'Unknown error')}")
|
84 |
+
except requests.exceptions.ConnectionError:
|
85 |
+
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
86 |
+
except Exception as e:
|
87 |
+
st.error(f"An unexpected error occurred: {e}")
|
88 |
+
|
89 |
+
# --- Causal Discovery Module ---
|
90 |
+
st.header("2. Causal Discovery 🕵️♂️")
|
91 |
+
if st.session_state.processed_data:
|
92 |
+
st.write("Learn the causal structure from your preprocessed data.")
|
93 |
+
|
94 |
+
discovery_algo = st.selectbox(
|
95 |
+
"Select Causal Discovery Algorithm:",
|
96 |
+
("PC Algorithm", "GES (Greedy Equivalence Search) - Placeholder", "NOTEARS - Placeholder")
|
97 |
+
)
|
98 |
+
|
99 |
+
if st.button("Discover Causal Graph"):
|
100 |
+
st.info(f"Discovering graph using {discovery_algo}...")
|
101 |
+
algo_map = {
|
102 |
+
"PC Algorithm": "pc",
|
103 |
+
"GES (Greedy Equivalence Search) - Placeholder": "ges",
|
104 |
+
"NOTEARS - Placeholder": "notears"
|
105 |
+
}
|
106 |
+
selected_algo_code = algo_map[discovery_algo]
|
107 |
+
|
108 |
+
try:
|
109 |
+
response = requests.post(
|
110 |
+
f"{FLASK_API_URL}/discover/",
|
111 |
+
json={"data": st.session_state.processed_data, "algorithm": selected_algo_code}
|
112 |
+
)
|
113 |
+
if response.status_code == 200:
|
114 |
+
result = response.json()
|
115 |
+
st.session_state.causal_graph_adj = result['graph']
|
116 |
+
st.session_state.causal_graph_nodes = st.session_state.processed_columns
|
117 |
+
st.success("Causal graph discovered!")
|
118 |
+
st.subheader("Causal Graph Visualization")
|
119 |
+
# Visualization will be handled by the Causal Graph Visualizer section
|
120 |
+
else:
|
121 |
+
st.error(f"Error during causal discovery: {response.json().get('detail', 'Unknown error')}")
|
122 |
+
except requests.exceptions.ConnectionError:
|
123 |
+
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
124 |
+
except Exception as e:
|
125 |
+
st.error(f"An unexpected error occurred: {e}")
|
126 |
+
else:
|
127 |
+
st.info("Please preprocess data first to enable causal discovery.")
|
128 |
+
|
129 |
+
# --- Causal Graph Visualizer Module ---
|
130 |
+
st.header("3. Causal Graph Visualizer 📊")
|
131 |
+
if st.session_state.causal_graph_adj and st.session_state.causal_graph_nodes:
|
132 |
+
st.write("Interactive visualization of the discovered causal graph.")
|
133 |
+
try:
|
134 |
+
response = requests.post(
|
135 |
+
f"{FLASK_API_URL}/visualize/graph",
|
136 |
+
json={"graph": st.session_state.causal_graph_adj, "nodes": st.session_state.causal_graph_nodes}
|
137 |
+
)
|
138 |
+
if response.status_code == 200:
|
139 |
+
graph_json = response.json()['graph']
|
140 |
+
fig = go.Figure(json.loads(graph_json))
|
141 |
+
st.plotly_chart(fig, use_container_width=True)
|
142 |
+
st.markdown("""
|
143 |
+
**Graph Explanation:**
|
144 |
+
* **Nodes:** Represent variables in your dataset.
|
145 |
+
* **Arrows (Edges):** Indicate a direct causal influence from one variable (the tail) to another (the head).
|
146 |
+
* **No Arrow:** Suggests no direct causal relationship was found, or the relationship is mediated by other variables.
|
147 |
+
|
148 |
+
This graph helps answer "Why did it happen?" by showing the structural relationships.
|
149 |
+
""")
|
150 |
+
else:
|
151 |
+
st.error(f"Error visualizing graph: {response.json().get('detail', 'Unknown error')}")
|
152 |
+
except requests.exceptions.ConnectionError:
|
153 |
+
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
154 |
+
except Exception as e:
|
155 |
+
st.error(f"An unexpected error occurred during visualization: {e}")
|
156 |
+
else:
|
157 |
+
st.info("Please discover a causal graph first to visualize it.")
|
158 |
+
|
159 |
+
|
160 |
+
# --- Do-Calculus Engine Module ---
|
161 |
+
st.header("4. Do-Calculus Engine 🧪")
|
162 |
+
if st.session_state.processed_data and st.session_state.causal_graph_adj:
|
163 |
+
st.write("Simulate interventions and observe their effects based on the causal graph.")
|
164 |
+
|
165 |
+
intervention_var = st.selectbox(
|
166 |
+
"Select variable to intervene on:",
|
167 |
+
st.session_state.processed_columns,
|
168 |
+
key="inter_var_select"
|
169 |
+
)
|
170 |
+
# Attempt to infer type for intervention_value input
|
171 |
+
# Simplified approach: assuming numerical for now due to preprocessor output
|
172 |
+
if intervention_var and isinstance(st.session_state.processed_data[0][intervention_var], (int, float)):
|
173 |
+
intervention_value = st.number_input(f"Set '{intervention_var}' to value:", key="inter_val_input")
|
174 |
+
else: # Treat as string/categorical for input, then try to preprocess for API
|
175 |
+
intervention_value = st.text_input(f"Set '{intervention_var}' to value:", key="inter_val_input_text")
|
176 |
+
st.warning("Categorical intervention values might require specific encoding logic on the backend.")
|
177 |
+
|
178 |
+
if st.button("Perform Intervention"):
|
179 |
+
st.info(f"Performing intervention: do('{intervention_var}' = {intervention_value})...")
|
180 |
+
try:
|
181 |
+
response = requests.post(
|
182 |
+
f"{FLASK_API_URL}/intervene/",
|
183 |
+
json={
|
184 |
+
"data": st.session_state.processed_data,
|
185 |
+
"intervention_var": intervention_var,
|
186 |
+
"intervention_value": intervention_value,
|
187 |
+
"graph": st.session_state.causal_graph_adj # Pass graph for advanced do-calculus
|
188 |
+
}
|
189 |
+
)
|
190 |
+
if response.status_code == 200:
|
191 |
+
intervened_data = pd.DataFrame(response.json()['intervened_data'])
|
192 |
+
st.success("Intervention simulated successfully!")
|
193 |
+
st.subheader("Intervened Data (First 10 rows)")
|
194 |
+
st.dataframe(intervened_data.head(10))
|
195 |
+
|
196 |
+
# Simple comparison visualization (e.g., histogram of outcome variable)
|
197 |
+
if st.session_state.processed_columns and 'FinalExamScore' in st.session_state.processed_columns:
|
198 |
+
original_df = pd.DataFrame(st.session_state.processed_data)
|
199 |
+
fig_dist = go.Figure()
|
200 |
+
fig_dist.add_trace(go.Histogram(x=original_df['FinalExamScore'], name='Original', opacity=0.7))
|
201 |
+
fig_dist.add_trace(go.Histogram(x=intervened_data['FinalExamScore'], name='Intervened', opacity=0.7))  # match the 'Original' trace so both distributions stay visible
|
202 |
+
|
203 |
+
st.plotly_chart(fig_dist, use_container_width=True)
|
204 |
+
st.markdown("""
|
205 |
+
**Intervention Explanation:**
|
206 |
+
* By simulating `do(X=x)`, we are forcing the value of X, effectively breaking its causal links from its parents.
|
207 |
+
* The graph above shows the distribution of a key outcome variable (e.g., `FinalExamScore`) before and after the intervention.
|
208 |
+
* This helps answer "What if we do this instead?" by showing the predicted outcome.
|
209 |
+
""")
|
210 |
+
else:
|
211 |
+
st.info("Consider adding a relevant outcome variable to your dataset for better intervention analysis.")
|
212 |
+
else:
|
213 |
+
st.error(f"Error during intervention: {response.json().get('detail', 'Unknown error')}")
|
214 |
+
except requests.exceptions.ConnectionError:
|
215 |
+
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
216 |
+
except Exception as e:
|
217 |
+
st.error(f"An unexpected error occurred during intervention: {e}")
|
218 |
+
else:
|
219 |
+
st.info("Please preprocess data and discover a causal graph first to perform interventions.")
|
220 |
+
|
221 |
+
# --- Treatment Effect Estimator Module ---
|
222 |
+
st.header("5. Treatment Effect Estimator 🎯")
|
223 |
+
if st.session_state.processed_data:
|
224 |
+
st.write("Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).")
|
225 |
+
|
226 |
+
col1, col2 = st.columns(2)
|
227 |
+
with col1:
|
228 |
+
treatment_col = st.selectbox(
|
229 |
+
"Select Treatment Variable:",
|
230 |
+
st.session_state.processed_columns,
|
231 |
+
key="treat_col_select"
|
232 |
+
)
|
233 |
+
with col2:
|
234 |
+
outcome_col = st.selectbox(
|
235 |
+
"Select Outcome Variable:",
|
236 |
+
st.session_state.processed_columns,
|
237 |
+
key="outcome_col_select"
|
238 |
+
)
|
239 |
+
|
240 |
+
all_cols_except_treat_outcome = [col for col in st.session_state.processed_columns if col not in [treatment_col, outcome_col]]
|
241 |
+
covariates = st.multiselect(
|
242 |
+
"Select Covariates (confounders):",
|
243 |
+
all_cols_except_treat_outcome,
|
244 |
+
default=all_cols_except_treat_outcome, # Default to all other columns
|
245 |
+
key="covariates_select"
|
246 |
+
)
|
247 |
+
|
248 |
+
estimation_method = st.selectbox(
|
249 |
+
"Select Estimation Method:",
|
250 |
+
(
|
251 |
+
"Linear Regression ATE",
|
252 |
+
"Propensity Score Matching - Placeholder",
|
253 |
+
"Inverse Propensity Weighting - Placeholder",
|
254 |
+
"T-learner - Placeholder",
|
255 |
+
"S-learner - Placeholder"
|
256 |
+
)
|
257 |
+
)
|
258 |
+
|
259 |
+
if st.button("Estimate Treatment Effect"):
|
260 |
+
st.info(f"Estimating treatment effect using {estimation_method}...")
|
261 |
+
method_map = {
|
262 |
+
"Linear Regression ATE": "linear_regression",
|
263 |
+
"Propensity Score Matching - Placeholder": "propensity_score_matching",
|
264 |
+
"Inverse Propensity Weighting - Placeholder": "inverse_propensity_weighting",
|
265 |
+
"T-learner - Placeholder": "t_learner",
|
266 |
+
"S-learner - Placeholder": "s_learner"
|
267 |
+
}
|
268 |
+
selected_method_code = method_map[estimation_method]
|
269 |
+
|
270 |
+
try:
|
271 |
+
response = requests.post(
|
272 |
+
f"{FLASK_API_URL}/treatment/estimate_ate",
|
273 |
+
json={
|
274 |
+
"data": st.session_state.processed_data,
|
275 |
+
"treatment_col": treatment_col,
|
276 |
+
"outcome_col": outcome_col,
|
277 |
+
"covariates": covariates,
|
278 |
+
"method": selected_method_code
|
279 |
+
}
|
280 |
+
)
|
281 |
+
if response.status_code == 200:
|
282 |
+
ate_result = response.json()['result']
|
283 |
+
st.success(f"Treatment effect estimated using {estimation_method}:")
|
284 |
+
st.write(f"**Estimated ATE: {ate_result:.4f}**")
|
285 |
+
st.markdown("""
|
286 |
+
**Treatment Effect Explanation:**
|
287 |
+
* **Average Treatment Effect (ATE):** Measures the average causal effect of a treatment (e.g., `StudyHours`) on an outcome (e.g., `FinalExamScore`) across the entire population.
|
288 |
+
* It answers "How much does doing X cause a change in Y?".
|
289 |
+
* This estimation attempts to control for confounders (variables that influence both treatment and outcome) to isolate the true causal effect.
|
290 |
+
""")
|
291 |
+
else:
|
292 |
+
st.error(f"Error during ATE estimation: {response.json().get('detail', 'Unknown error')}")
|
293 |
+
except requests.exceptions.ConnectionError:
|
294 |
+
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
295 |
+
except Exception as e:
|
296 |
+
st.error(f"An unexpected error occurred during ATE estimation: {e}")
|
297 |
+
else:
|
298 |
+
st.info("Please preprocess data first to estimate treatment effects.")
|
299 |
+
|
300 |
+
# --- Prediction Module ---
|
301 |
+
st.header("6. Prediction Module 📈")
|
302 |
+
if st.session_state.processed_data:
|
303 |
+
st.write("Train a machine learning model for prediction (Regression or Classification).")
|
304 |
+
|
305 |
+
prediction_type = st.selectbox(
|
306 |
+
"Select Prediction Type:",
|
307 |
+
("Regression", "Classification"),
|
308 |
+
key="prediction_type_select"
|
309 |
+
)
|
310 |
+
|
311 |
+
all_columns = st.session_state.processed_columns
|
312 |
+
|
313 |
+
suitable_target_columns = []
|
314 |
+
if st.session_state.processed_data:
|
315 |
+
temp_df = pd.DataFrame(st.session_state.processed_data)
|
316 |
+
for col in all_columns:
|
317 |
+
# For classification, check if column is object type (string), boolean,
|
318 |
+
# or has a limited number of unique integer values (e.g., less than 20 unique values)
|
319 |
+
if prediction_type == 'Classification':
|
320 |
+
if temp_df[col].dtype == 'object' or temp_df[col].dtype == 'bool':
|
321 |
+
suitable_target_columns.append(col)
|
322 |
+
elif pd.api.types.is_integer_dtype(temp_df[col]) and temp_df[col].nunique() < 20: # Heuristic for discrete integers
|
323 |
+
suitable_target_columns.append(col)
|
324 |
+
# For regression, primarily numerical columns
|
325 |
+
elif prediction_type == 'Regression':
|
326 |
+
if pd.api.types.is_numeric_dtype(temp_df[col]):
|
327 |
+
suitable_target_columns.append(col)
|
328 |
+
|
329 |
+
if not suitable_target_columns:
|
330 |
+
st.warning(f"No suitable target columns found for {prediction_type}. Please check your data types.")
|
331 |
+
target_col = None # Set to None to prevent error if no columns are found
|
332 |
+
else:
|
333 |
+
# Try to pre-select the currently chosen target_col if it's still suitable
|
334 |
+
# Otherwise, default to the first suitable column
|
335 |
+
if 'target_col_select' in st.session_state and st.session_state.target_col_select in suitable_target_columns:
|
336 |
+
default_target_index = suitable_target_columns.index(st.session_state.target_col_select)
|
337 |
+
else:
|
338 |
+
default_target_index = 0
|
339 |
+
|
340 |
+
target_col = st.selectbox(
|
341 |
+
"Select Target Variable:",
|
342 |
+
suitable_target_columns,
|
343 |
+
index=default_target_index,
|
344 |
+
key="target_col_select"
|
345 |
+
)
|
346 |
+
|
347 |
+
# Filter out the target column from feature options
|
348 |
+
feature_options = [col for col in all_columns if col != target_col]
|
349 |
+
feature_cols = st.multiselect(
|
350 |
+
"Select Feature Variables:",
|
351 |
+
feature_options,
|
352 |
+
default=feature_options, # Default to all other columns
|
353 |
+
key="feature_cols_select"
|
354 |
+
)
|
355 |
+
|
356 |
+
if st.button("Train Model & Predict", key="train_predict_button"):
|
357 |
+
if not target_col or not feature_cols:
|
358 |
+
st.warning("Please select a target variable and at least one feature variable.")
|
359 |
+
else:
|
360 |
+
st.info(f"Training {prediction_type} model using Random Forest...")
|
361 |
+
try:
|
362 |
+
response = requests.post(
|
363 |
+
f"{FLASK_API_URL}/prediction/train_predict",
|
364 |
+
json={
|
365 |
+
"data": st.session_state.processed_data,
|
366 |
+
"target_col": target_col,
|
367 |
+
"feature_cols": feature_cols,
|
368 |
+
"prediction_type": prediction_type.lower()
|
369 |
+
}
|
370 |
+
)
|
371 |
+
|
372 |
+
if response.status_code == 200:
|
373 |
+
results = response.json()['results']
|
374 |
+
st.success(f"{prediction_type} Model Trained Successfully!")
|
375 |
+
st.subheader("Model Performance")
|
376 |
+
|
377 |
+
if prediction_type == 'Regression':
|
378 |
+
st.write(f"**R-squared:** {results['r2_score']:.4f}")
|
379 |
+
st.write(f"**Mean Squared Error (MSE):** {results['mean_squared_error']:.4f}")
|
380 |
+
st.write(f"**Root Mean Squared Error (RMSE):** {results['root_mean_squared_error']:.4f}")
|
381 |
+
|
382 |
+
st.subheader("Actual vs. Predicted Plot")
|
383 |
+
actual_predicted_df = pd.DataFrame(results['actual_vs_predicted'])
|
384 |
+
fig_reg = px.scatter(actual_predicted_df, x='Actual', y='Predicted',
|
385 |
+
title='Actual vs. Predicted Values',
|
386 |
+
labels={'Actual': f'Actual {target_col}', 'Predicted': f'Predicted {target_col}'})
|
387 |
+
fig_reg.add_trace(go.Scatter(x=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
|
388 |
+
y=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
|
389 |
+
mode='lines', name='Ideal Fit', line=dict(dash='dash', color='red')))
|
390 |
+
st.plotly_chart(fig_reg, use_container_width=True)
|
391 |
+
|
392 |
+
st.subheader("Residual Plot")
|
393 |
+
actual_predicted_df['Residuals'] = actual_predicted_df['Actual'] - actual_predicted_df['Predicted']
|
394 |
+
fig_res = px.scatter(actual_predicted_df, x='Predicted', y='Residuals',
|
395 |
+
title='Residual Plot',
|
396 |
+
labels={'Predicted': f'Predicted {target_col}', 'Residuals': 'Residuals'})
|
397 |
+
fig_res.add_hline(y=0, line_dash="dash", line_color="red")
|
398 |
+
st.plotly_chart(fig_res, use_container_width=True)
|
399 |
+
|
400 |
+
elif prediction_type == 'Classification':
|
401 |
+
st.write(f"**Accuracy:** {results['accuracy']:.4f}")
|
402 |
+
st.write(f"**Precision (weighted):** {results['precision']:.4f}")
|
403 |
+
st.write(f"**Recall (weighted):** {results['recall']:.4f}")
|
404 |
+
st.write(f"**F1-Score (weighted):** {results['f1_score']:.4f}")
|
405 |
+
|
406 |
+
st.subheader("Confusion Matrix")
|
407 |
+
conf_matrix = results['confusion_matrix']
|
408 |
+
class_labels = results.get('class_labels', [str(i) for i in range(len(conf_matrix))])
|
409 |
+
fig_cm = px.imshow(conf_matrix,
|
410 |
+
labels=dict(x="Predicted", y="True", color="Count"),
|
411 |
+
x=class_labels,
|
412 |
+
y=class_labels,
|
413 |
+
text_auto=True,
|
414 |
+
color_continuous_scale="Viridis",
|
415 |
+
title="Confusion Matrix")
|
416 |
+
st.plotly_chart(fig_cm, use_container_width=True)
|
417 |
+
|
418 |
+
st.subheader("Classification Report")
|
419 |
+
# Convert dict to DataFrame for nice display
|
420 |
+
report_df = pd.DataFrame(results['classification_report']).transpose()
|
421 |
+
st.dataframe(report_df)
|
422 |
+
|
423 |
+
st.subheader("Feature Importances")
|
424 |
+
feature_importances_df = pd.DataFrame(list(results['feature_importances'].items()), columns=['Feature', 'Importance'])
|
425 |
+
fig_fi = px.bar(feature_importances_df, x='Importance', y='Feature', orientation='h',
|
426 |
+
title='Feature Importances',
|
427 |
+
labels={'Importance': 'Importance Score', 'Feature': 'Feature Name'})
|
428 |
+
fig_fi.update_layout(yaxis={'categoryorder':'total ascending'}) # Sort bars
|
429 |
+
st.plotly_chart(fig_fi, use_container_width=True)
|
430 |
+
else:
|
431 |
+
st.error(f"Error during prediction: {response.json().get('detail', 'Unknown error')}")
|
432 |
+
except requests.exceptions.ConnectionError:
|
433 |
+
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
434 |
+
except Exception as e:
|
435 |
+
st.error(f"An unexpected error occurred during prediction: {e}")
|
436 |
+
else:
|
437 |
+
st.info("Please preprocess data first to use the Prediction Module.")
|
438 |
+
|
439 |
+
# --- Time Series Causal Discovery Module ---
|
440 |
+
st.header("7. Time Series Causal Discovery ⏰")
|
441 |
+
if st.session_state.processed_data:
|
442 |
+
st.write("Infer causal relationships in time-series data using Granger Causality.")
|
443 |
+
st.info("Ensure your dataset includes a timestamp column and that variables are numeric.")
|
444 |
+
|
445 |
+
all_columns = st.session_state.processed_columns
|
446 |
+
|
447 |
+
# Heuristic to suggest potential timestamp columns (object/string type, or first column)
|
448 |
+
potential_ts_cols = [col for col in all_columns if pd.DataFrame(st.session_state.processed_data)[col].dtype == 'object']
|
449 |
+
if not potential_ts_cols and all_columns: # If no object columns, suggest the first column
|
450 |
+
potential_ts_cols = [all_columns[0]]
|
451 |
+
|
452 |
+
timestamp_col = st.selectbox(
|
453 |
+
"Select Timestamp Column:",
|
454 |
+
potential_ts_cols if potential_ts_cols else ["No suitable timestamp column found. Please check data."],
|
455 |
+
key="ts_col_select"
|
456 |
+
)
|
457 |
+
|
458 |
+
# Filter out timestamp column and non-numeric columns for analysis
|
459 |
+
variables_for_ts_analysis = [
|
460 |
+
col for col in all_columns if col != timestamp_col and pd.api.types.is_numeric_dtype(pd.DataFrame(st.session_state.processed_data)[col])
|
461 |
+
]
|
462 |
+
|
463 |
+
variables_to_analyze = st.multiselect(
|
464 |
+
"Select Variables to Analyze for Granger Causality:",
|
465 |
+
variables_for_ts_analysis,
|
466 |
+
default=variables_for_ts_analysis,
|
467 |
+
key="ts_vars_select"
|
468 |
+
)
|
469 |
+
|
470 |
+
max_lags = st.number_input(
|
471 |
+
"Max Lags (for Granger Causality):",
|
472 |
+
min_value=1,
|
473 |
+
value=5, # Default value
|
474 |
+
step=1,
|
475 |
+
help="The maximum number of lagged observations to consider for causality."
|
476 |
+
)
|
477 |
+
|
478 |
+
if st.button("Discover Time Series Causality", key="ts_discover_button"):
|
479 |
+
if not timestamp_col or not variables_to_analyze:
|
480 |
+
st.warning("Please select a timestamp column and at least one variable to analyze.")
|
481 |
+
elif "No suitable timestamp column found" in timestamp_col:
|
482 |
+
st.error("Cannot proceed. Please ensure your data has a suitable timestamp column.")
|
483 |
+
else:
|
484 |
+
st.info("Performing Granger Causality tests...")
|
485 |
+
try:
|
486 |
+
response = requests.post(
|
487 |
+
f"{FLASK_API_URL}/timeseries/discover_causality",
|
488 |
+
json={
|
489 |
+
"data": st.session_state.processed_data,
|
490 |
+
"timestamp_col": timestamp_col,
|
491 |
+
"variables_to_analyze": variables_to_analyze,
|
492 |
+
"max_lags": max_lags
|
493 |
+
}
|
494 |
+
)
|
495 |
+
|
496 |
+
if response.status_code == 200:
|
497 |
+
results = response.json()['results']
|
498 |
+
st.success("Time Series Causal Discovery Complete!")
|
499 |
+
st.subheader("Granger Causality Test Results")
|
500 |
+
|
501 |
+
if results:
|
502 |
+
# Convert results to a DataFrame for better display
|
503 |
+
results_df = pd.DataFrame(results)
|
504 |
+
results_df['p_value'] = results_df['p_value'].round(4) # Round p-values
|
505 |
+
st.dataframe(results_df)
|
506 |
+
|
507 |
+
st.markdown("**Interpretation:** A small p-value (typically < 0.05) suggests that the 'cause' variable Granger-causes the 'effect' variable. This means past values of the 'cause' variable help predict future values of the 'effect' variable, even when past values of the 'effect' variable are considered.")
|
508 |
+
st.markdown(f"*(Note: Granger Causality implies predictive causality, not necessarily true mechanistic causality. Also, ensure your time series are stationary for robust results.)*")
|
509 |
+
|
510 |
+
# Optionally, visualize a simple causality graph
|
511 |
+
st.subheader("Granger Causality Graph")
|
512 |
+
fig_ts_graph = go.Figure()
|
513 |
+
nodes = []
|
514 |
+
edges = []
|
515 |
+
edge_colors = []
|
516 |
+
|
517 |
+
# Add nodes
|
518 |
+
for i, var in enumerate(variables_to_analyze):
|
519 |
+
nodes.append(dict(id=var, label=var, x=np.cos(i*2*np.pi/len(variables_to_analyze)), y=np.sin(i*2*np.pi/len(variables_to_analyze))))
|
520 |
+
|
521 |
+
# Add edges
|
522 |
+
for res in results:
|
523 |
+
if res['p_value'] < 0.05: # Consider it a causal link if p-value is below significance
|
524 |
+
edges.append(dict(source=res['cause'], target=res['effect'], value=1/res['p_value'], title=f"p={res['p_value']:.4f}"))
|
525 |
+
edge_colors.append("blue")
|
526 |
+
else:
|
527 |
+
# Optional: Show non-significant edges in a different color or omit
|
528 |
+
pass
|
529 |
+
|
530 |
+
# Use a simple network graph layout (Spring layout is common)
|
531 |
+
# For a truly interactive graph, you might need a different library or more complex Plotly setup
|
532 |
+
# This is a very basic attempt to visualize; consider more robust solutions like NetworkX + Plotly/Dash
|
533 |
+
|
534 |
+
# Simple way to draw arrows for significant relationships
|
535 |
+
significant_edges = [edge for edge in results if edge['p_value'] < 0.05]
|
536 |
+
if significant_edges:
|
537 |
+
st.write("Visualizing significant (p < 0.05) Granger causal links:")
|
538 |
+
# This needs a more robust way to draw directed edges in plotly if using just scatter/lines.
|
539 |
+
# For now, let's just list them clearly.
|
540 |
+
for edge in significant_edges:
|
541 |
+
st.write(f"➡️ **{edge['cause']}** Granger-causes **{edge['effect']}** (p={edge['p_value']:.4f})")
|
542 |
+
else:
|
543 |
+
st.info("No significant Granger causal links found at p < 0.05.")
|
544 |
+
|
545 |
+
else:
|
546 |
+
st.info("No Granger Causality relationships found or data insufficient.")
|
547 |
+
|
548 |
+
else:
|
549 |
+
st.error(f"Error during time-series causal discovery: {response.json().get('detail', 'Unknown error')}")
|
550 |
+
except requests.exceptions.ConnectionError:
|
551 |
+
st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
|
552 |
+
except Exception as e:
|
553 |
+
st.error(f"An unexpected error occurred during time-series causal discovery: {e}")
|
554 |
+
else:
|
555 |
+
st.info("Please preprocess data first to use the Time Series Causal Discovery Module.")
|
556 |
+
|
557 |
+
# --- CausalBox Chat Assistant ---
|
558 |
+
st.header("8. CausalBox Chat Assistant 🤖")
|
559 |
+
st.write("Ask questions about your loaded dataset, causal concepts, or the discovered causal graph!")
|
560 |
+
|
561 |
+
# Initialize chat history in session state
|
562 |
+
if "messages" not in st.session_state:
|
563 |
+
st.session_state.messages = []
|
564 |
+
|
565 |
+
# Display chat messages from history on app rerun
|
566 |
+
for message in st.session_state.messages:
|
567 |
+
with st.chat_message(message["role"]):
|
568 |
+
st.markdown(message["content"])
|
569 |
+
|
570 |
+
# Accept user input
|
571 |
+
if prompt := st.chat_input("Ask me anything about CausalBox..."):
|
572 |
+
# Add user message to chat history
|
573 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
574 |
+
# Display user message in chat message container
|
575 |
+
with st.chat_message("user"):
|
576 |
+
st.markdown(prompt)
|
577 |
+
|
578 |
+
# Prepare session context to send to the backend
|
579 |
+
session_context = {
|
580 |
+
"processed_data": st.session_state.processed_data,
|
581 |
+
"processed_columns": st.session_state.processed_columns,
|
582 |
+
"causal_graph_adj": st.session_state.causal_graph_adj,
|
583 |
+
"causal_graph_nodes": st.session_state.causal_graph_nodes,
|
584 |
+
# Add any other relevant session state variables that the chatbot might need
|
585 |
+
}
|
586 |
+
|
587 |
+
with st.spinner("Thinking..."):
|
588 |
+
try:
|
589 |
+
response = requests.post(
|
590 |
+
f"{FLASK_API_URL}/chatbot/message",
|
591 |
+
json={
|
592 |
+
"user_message": prompt,
|
593 |
+
"session_context": session_context
|
594 |
+
}
|
595 |
+
)
|
596 |
+
|
597 |
+
if response.status_code == 200:
|
598 |
+
chatbot_response_text = response.json().get('response', 'Sorry, I could not generate a response.')
|
599 |
+
else:
|
600 |
+
chatbot_response_text = f"Error from chatbot backend: {response.json().get('detail', 'Unknown error')}"
|
601 |
+
except requests.exceptions.ConnectionError:
|
602 |
+
chatbot_response_text = f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running."
|
603 |
+
except Exception as e:
|
604 |
+
chatbot_response_text = f"An unexpected error occurred while getting chatbot response: {e}"
|
605 |
+
|
606 |
+
# Display assistant response in chat message container
|
607 |
+
with st.chat_message("assistant"):
|
608 |
+
st.markdown(chatbot_response_text)
|
609 |
+
# Add assistant response to chat history
|
610 |
+
st.session_state.messages.append({"role": "assistant", "content": chatbot_response_text})
|
611 |
+
|
612 |
+
# --- Future Work (Simplified) ---
|
613 |
+
st.header("Future Work 🚀")
|
614 |
+
st.markdown("""
|
615 |
+
- **🔄 Auto-causal graph refresh:** Monitor dataset updates and automatically refresh the causal graph.
|
616 |
+
""")
|
617 |
+
|
618 |
+
st.markdown("---")
|
619 |
st.info("Developed by CausalBox Team. For support, please contact us.")
|
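Editor's note on the Time Series module above: the app only displays Granger-causality results returned by the backend. A minimal sketch of the kind of pairwise test the /timeseries/discover_causality route is assumed to run (via statsmodels) is shown here; the function name and p-value handling are illustrative, and the series are assumed to be numeric and stationary.

# Minimal sketch: pairwise Granger causality tests with statsmodels.
# Returns one p-value per ordered (cause, effect) pair; stationarity is assumed.
import pandas as pd
from statsmodels.tsa.stattools import grangercausalitytests

def granger_pairs(df: pd.DataFrame, variables: list, max_lags: int = 5) -> list:
    results = []
    for cause in variables:
        for effect in variables:
            if cause == effect:
                continue
            # grangercausalitytests checks whether the second column helps predict the first
            res = grangercausalitytests(df[[effect, cause]], maxlag=max_lags, verbose=False)
            p_value = min(r[0]["ssr_ftest"][1] for r in res.values())  # best p-value across lags
            results.append({"cause": cause, "effect": effect, "p_value": p_value})
    return results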
utils/__pycache__/casual_algorithms.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/casual_algorithms.cpython-310.pyc and b/utils/__pycache__/casual_algorithms.cpython-310.pyc differ
|
|
utils/__pycache__/causal_chatbot.cpython-310.pyc
ADDED
Binary file (10.2 kB). View file
|
|
utils/__pycache__/do_calculus.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/do_calculus.cpython-310.pyc and b/utils/__pycache__/do_calculus.cpython-310.pyc differ
|
|
utils/__pycache__/graph_utils.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/graph_utils.cpython-310.pyc and b/utils/__pycache__/graph_utils.cpython-310.pyc differ
|
|
utils/__pycache__/prediction_models.cpython-310.pyc
ADDED
Binary file (3.09 kB). View file
|
|
utils/__pycache__/preprocessor.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/preprocessor.cpython-310.pyc and b/utils/__pycache__/preprocessor.cpython-310.pyc differ
|
|
utils/__pycache__/time_series_causal.cpython-310.pyc
ADDED
Binary file (2.21 kB). View file
|
|
utils/__pycache__/treatment_effects.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/treatment_effects.cpython-310.pyc and b/utils/__pycache__/treatment_effects.cpython-310.pyc differ
|
|
utils/casual_algorithms.py
CHANGED
@@ -1,64 +1,64 @@
|
|
1 |
-
# utils/causal_algorithms.py
|
2 |
-
import networkx as nx
|
3 |
-
import pandas as pd
|
4 |
-
import numpy as np
|
5 |
-
from causallearn.search.ConstraintBased.PC import pc
|
6 |
-
# from causallearn.search.ScoreBased.GES import ges # Example import for GES
|
7 |
-
# from notears import notears_linear # Example import for NOTEARS
|
8 |
-
|
9 |
-
class CausalDiscoveryAlgorithms:
|
10 |
-
def pc_algorithm(self, df, alpha=0.05):
|
11 |
-
"""
|
12 |
-
Run PC algorithm to learn causal graph.
|
13 |
-
Returns a directed graph's adjacency matrix.
|
14 |
-
Requires numerical data.
|
15 |
-
"""
|
16 |
-
data_array = df.to_numpy()
|
17 |
-
cg = pc(data_array, alpha=alpha, indep_test="fisherz")
|
18 |
-
adj_matrix = cg.G.graph
|
19 |
-
return adj_matrix
|
20 |
-
|
21 |
-
def ges_algorithm(self, df):
|
22 |
-
"""
|
23 |
-
Placeholder for GES (Greedy Equivalence Search) algorithm.
|
24 |
-
Returns a directed graph's adjacency matrix.
|
25 |
-
You would implement or integrate the GES algorithm here.
|
26 |
-
"""
|
27 |
-
# Example: G, edges = ges(data_array)
|
28 |
-
# For now, returning a simplified correlation-based graph for demonstration
|
29 |
-
print("GES algorithm is a placeholder. Using a simplified correlation-based graph.")
|
30 |
-
G = nx.DiGraph()
|
31 |
-
nodes = df.columns
|
32 |
-
G.add_nodes_from(nodes)
|
33 |
-
corr_matrix = df.corr().abs()
|
34 |
-
threshold = 0.3
|
35 |
-
for i, col1 in enumerate(nodes):
|
36 |
-
for col2 in nodes[i+1:]:
|
37 |
-
if corr_matrix.loc[col1, col2] > threshold:
|
38 |
-
if np.random.rand() > 0.5:
|
39 |
-
G.add_edge(col1, col2)
|
40 |
-
else:
|
41 |
-
G.add_edge(col2, col1)
|
42 |
-
return nx.to_numpy_array(G) # Convert to adjacency matrix
|
43 |
-
|
44 |
-
def notears_algorithm(self, df):
|
45 |
-
"""
|
46 |
-
Placeholder for NOTEARS algorithm.
|
47 |
-
Returns a directed graph's adjacency matrix.
|
48 |
-
You would implement or integrate the NOTEARS algorithm here.
|
49 |
-
"""
|
50 |
-
# Example: W_est = notears_linear(data_array)
|
51 |
-
print("NOTEARS algorithm is a placeholder. Using a simplified correlation-based graph.")
|
52 |
-
G = nx.DiGraph()
|
53 |
-
nodes = df.columns
|
54 |
-
G.add_nodes_from(nodes)
|
55 |
-
corr_matrix = df.corr().abs()
|
56 |
-
threshold = 0.3
|
57 |
-
for i, col1 in enumerate(nodes):
|
58 |
-
for col2 in nodes[i+1:]:
|
59 |
-
if corr_matrix.loc[col1, col2] > threshold:
|
60 |
-
if np.random.rand() > 0.5:
|
61 |
-
G.add_edge(col1, col2)
|
62 |
-
else:
|
63 |
-
G.add_edge(col2, col1)
|
64 |
return nx.to_numpy_array(G) # Convert to adjacency matrix
|
|
|
1 |
+
# utils/causal_algorithms.py
|
2 |
+
import networkx as nx
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from causallearn.search.ConstraintBased.PC import pc
|
6 |
+
# from causallearn.search.ScoreBased.GES import ges # Example import for GES
|
7 |
+
# from notears import notears_linear # Example import for NOTEARS
|
8 |
+
|
9 |
+
class CausalDiscoveryAlgorithms:
|
10 |
+
def pc_algorithm(self, df, alpha=0.05):
|
11 |
+
"""
|
12 |
+
Run PC algorithm to learn causal graph.
|
13 |
+
Returns a directed graph's adjacency matrix.
|
14 |
+
Requires numerical data.
|
15 |
+
"""
|
16 |
+
data_array = df.to_numpy()
|
17 |
+
cg = pc(data_array, alpha=alpha, indep_test="fisherz")
|
18 |
+
adj_matrix = cg.G.graph
|
19 |
+
return adj_matrix
|
20 |
+
|
21 |
+
def ges_algorithm(self, df):
|
22 |
+
"""
|
23 |
+
Placeholder for GES (Greedy Equivalence Search) algorithm.
|
24 |
+
Returns a directed graph's adjacency matrix.
|
25 |
+
You would implement or integrate the GES algorithm here.
|
26 |
+
"""
|
27 |
+
# Example: G, edges = ges(data_array)
|
28 |
+
# For now, returning a simplified correlation-based graph for demonstration
|
29 |
+
print("GES algorithm is a placeholder. Using a simplified correlation-based graph.")
|
30 |
+
G = nx.DiGraph()
|
31 |
+
nodes = df.columns
|
32 |
+
G.add_nodes_from(nodes)
|
33 |
+
corr_matrix = df.corr().abs()
|
34 |
+
threshold = 0.3
|
35 |
+
for i, col1 in enumerate(nodes):
|
36 |
+
for col2 in nodes[i+1:]:
|
37 |
+
if corr_matrix.loc[col1, col2] > threshold:
|
38 |
+
if np.random.rand() > 0.5:
|
39 |
+
G.add_edge(col1, col2)
|
40 |
+
else:
|
41 |
+
G.add_edge(col2, col1)
|
42 |
+
return nx.to_numpy_array(G) # Convert to adjacency matrix
|
43 |
+
|
44 |
+
def notears_algorithm(self, df):
|
45 |
+
"""
|
46 |
+
Placeholder for NOTEARS algorithm.
|
47 |
+
Returns a directed graph's adjacency matrix.
|
48 |
+
You would implement or integrate the NOTEARS algorithm here.
|
49 |
+
"""
|
50 |
+
# Example: W_est = notears_linear(data_array)
|
51 |
+
print("NOTEARS algorithm is a placeholder. Using a simplified correlation-based graph.")
|
52 |
+
G = nx.DiGraph()
|
53 |
+
nodes = df.columns
|
54 |
+
G.add_nodes_from(nodes)
|
55 |
+
corr_matrix = df.corr().abs()
|
56 |
+
threshold = 0.3
|
57 |
+
for i, col1 in enumerate(nodes):
|
58 |
+
for col2 in nodes[i+1:]:
|
59 |
+
if corr_matrix.loc[col1, col2] > threshold:
|
60 |
+
if np.random.rand() > 0.5:
|
61 |
+
G.add_edge(col1, col2)
|
62 |
+
else:
|
63 |
+
G.add_edge(col2, col1)
|
64 |
return nx.to_numpy_array(G) # Convert to adjacency matrix
|
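Editor's note on `pc_algorithm` above: causal-learn's `cg.G.graph` is an endpoint matrix rather than a plain 0/1 adjacency matrix, so downstream code usually converts it before visualization. The sketch below assumes causal-learn's convention that `graph[i, j] == -1` together with `graph[j, i] == 1` encodes the directed edge i -> j; undirected edges are skipped, and the helper name is illustrative.

# Minimal sketch: convert a causal-learn PC endpoint matrix to a NetworkX DiGraph.
# Assumption: graph[i, j] == -1 and graph[j, i] == 1 mean the directed edge i -> j.
import networkx as nx

def pc_matrix_to_digraph(adj_matrix, node_names):
    G = nx.DiGraph()
    G.add_nodes_from(node_names)
    n = len(node_names)
    for i in range(n):
        for j in range(n):
            if adj_matrix[i, j] == -1 and adj_matrix[j, i] == 1:
                G.add_edge(node_names[i], node_names[j])
    return G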
utils/causal_chatbot.py
ADDED
@@ -0,0 +1,271 @@
|
|
1 |
+
# utils/causal_chatbot.py
|
2 |
+
import os
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from langchain_groq import ChatGroq
|
5 |
+
from langchain_core.tools import tool
|
6 |
+
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
|
7 |
+
from langchain_core.prompts import ChatPromptTemplate
|
8 |
+
from utils.preprocessor import summarize_dataframe_for_chatbot
|
9 |
+
from utils.graph_utils import get_graph_summary_for_chatbot
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
# Configure Groq API Key
|
15 |
+
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
16 |
+
if not GROQ_API_KEY:
|
17 |
+
print("ERROR: GROQ_API_KEY environment variable not set.")
|
18 |
+
raise ValueError("GROQ_API_KEY is required.")
|
19 |
+
|
20 |
+
# Debug: Print API key details
|
21 |
+
print(f"Loaded GROQ_API_KEY: {GROQ_API_KEY[:5]}...{GROQ_API_KEY[-5:]}")
|
22 |
+
print(f"API Key Length: {len(GROQ_API_KEY)}")
|
23 |
+
|
24 |
+
# Initialize the Groq model with LangChain
|
25 |
+
try:
|
26 |
+
model = ChatGroq(
|
27 |
+
model_name="llama-3.3-70b-versatile",
|
28 |
+
temperature=0.7,
|
29 |
+
groq_api_key=GROQ_API_KEY
|
30 |
+
)
|
31 |
+
except Exception as e:
|
32 |
+
print(f"Error configuring Groq API: {e}")
|
33 |
+
model = None
|
34 |
+
|
35 |
+
def assess_causal_compatibility(data_json: list) -> str:
|
36 |
+
"""
|
37 |
+
Assesses the dataset's compatibility for causal inference analysis.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
data_json: List of dictionaries representing the dataset.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
String describing the dataset's suitability for causal analysis.
|
44 |
+
"""
|
45 |
+
if not data_json:
|
46 |
+
return "No dataset provided for compatibility assessment."
|
47 |
+
|
48 |
+
try:
|
49 |
+
df = pd.DataFrame(data_json)
|
50 |
+
num_rows, num_cols = df.shape
|
51 |
+
numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
|
52 |
+
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
|
53 |
+
missing_values = df.isnull().sum().sum()
|
54 |
+
|
55 |
+
assessment = [
|
56 |
+
f"Dataset has {num_rows} rows and {num_cols} columns.",
|
57 |
+
f"Numeric columns ({len(numeric_cols)}): {', '.join(numeric_cols) if len(numeric_cols) > 0 else 'None'}.",
|
58 |
+
f"Categorical columns ({len(categorical_cols)}): {', '.join(categorical_cols) if len(categorical_cols) > 0 else 'None'}.",
|
59 |
+
f"Missing values: {missing_values}."
|
60 |
+
]
|
61 |
+
|
62 |
+
# Causal compatibility insights
|
63 |
+
if num_cols < 3:
|
64 |
+
assessment.append("Warning: Dataset has fewer than 3 columns, which may limit causal analysis (e.g., no room for treatment, outcome, and confounders).")
|
65 |
+
if len(numeric_cols) == 0:
|
66 |
+
assessment.append("Warning: No numeric columns detected. Causal inference often requires numeric variables for treatment or outcome.")
|
67 |
+
if missing_values > 0:
|
68 |
+
assessment.append("Note: Missing values detected. Preprocessing (e.g., imputation) may be needed for accurate causal analysis.")
|
69 |
+
if len(numeric_cols) >= 2 and num_rows > 100:
|
70 |
+
assessment.append("Positive: Dataset has multiple numeric columns and sufficient rows, suitable for causal inference with proper preprocessing.")
|
71 |
+
else:
|
72 |
+
assessment.append("Note: Ensure at least two numeric columns (e.g., treatment and outcome) and sufficient data points for robust causal analysis.")
|
73 |
+
|
74 |
+
return "\n".join(assessment)
|
75 |
+
except Exception as e:
|
76 |
+
print(f"Error in assess_causal_compatibility: {e}")
|
77 |
+
return "Unable to assess dataset compatibility due to processing error."
|
78 |
+
|
79 |
+
# Define tools using LangChain's @tool decorator
|
80 |
+
@tool
|
81 |
+
def get_dataset_info() -> dict:
|
82 |
+
"""
|
83 |
+
Provides summary information and causal compatibility assessment for the currently loaded dataset.
|
84 |
+
The dataset is provided by the backend session context.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
Dictionary containing the dataset summary and compatibility assessment.
|
88 |
+
"""
|
89 |
+
return {"summary": "Dataset will be provided by session context"}
|
90 |
+
|
91 |
+
@tool
|
92 |
+
def get_causal_graph_info() -> dict:
|
93 |
+
"""
|
94 |
+
Provides summary information about the currently discovered causal graph.
|
95 |
+
The graph data is provided by the backend session context.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
Dictionary containing the graph summary.
|
99 |
+
"""
|
100 |
+
return {"summary": "Graph data will be provided by session context"}
|
101 |
+
|
102 |
+
# Bind tools to the model
|
103 |
+
tools = [get_dataset_info, get_causal_graph_info]
|
104 |
+
if model:
|
105 |
+
model_with_tools = model.bind_tools(tools)
|
106 |
+
|
107 |
+
def get_chatbot_response(user_message: str, session_context: dict) -> str:
|
108 |
+
"""
|
109 |
+
Gets a response from the Groq chatbot, handling tool calls.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
user_message: The message from the user.
|
113 |
+
session_context: Dictionary containing current session data
|
114 |
+
(e.g., processed_data, causal_graph_adj, causal_graph_nodes).
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
The chatbot's response message.
|
118 |
+
"""
|
119 |
+
if model is None:
|
120 |
+
return "Chatbot is not configured correctly. Please check Groq API key."
|
121 |
+
|
122 |
+
try:
|
123 |
+
# Create a prompt template to guide the model's behavior
|
124 |
+
prompt = ChatPromptTemplate.from_messages([
|
125 |
+
("system", """You are CausalBox Assistant, an AI that helps users analyze datasets and causal graphs.
|
126 |
+
Use the provided tools to access dataset or graph information. Do NOT generate or guess parameters for tool calls; the backend will provide all necessary data (e.g., dataset or graph details).
|
127 |
+
For dataset queries (e.g., "read the dataset", "dataset compatibility"), call `get_dataset_info` without arguments.
|
128 |
+
For graph queries (e.g., "describe the causal graph"), call `get_causal_graph_info` without arguments.
|
129 |
+
For other questions (e.g., "what is a confounder?"), respond directly with clear, accurate explanations.
|
130 |
+
|
131 |
+
When you receive tool results, provide a comprehensive analysis and explanation to help the user understand their data and causal analysis possibilities.
|
132 |
+
|
133 |
+
Examples:
|
134 |
+
- User: "Tell me about the dataset" -> Call `get_dataset_info`.
|
135 |
+
- User: "Check dataset compatibility for causal analysis" -> Call `get_dataset_info`.
|
136 |
+
- User: "Describe the causal graph" -> Call `get_causal_graph_info`.
|
137 |
+
- User: "What is a confounder?" -> Respond: "A confounder is a variable that influences both the treatment and outcome, causing a spurious association."
|
138 |
+
"""),
|
139 |
+
("human", "{user_message}")
|
140 |
+
])
|
141 |
+
|
142 |
+
# Chain the prompt with the model
|
143 |
+
chain = prompt | model_with_tools
|
144 |
+
|
145 |
+
# Log the user message and session context
|
146 |
+
print(f"Processing user message: {user_message}")
|
147 |
+
print(f"Session context keys: {list(session_context.keys())}")
|
148 |
+
|
149 |
+
# Invoke the chain with the user message
|
150 |
+
response = chain.invoke({"user_message": user_message})
|
151 |
+
print(f"Model response: {response}")
|
152 |
+
|
153 |
+
# Handle tool calls if present
|
154 |
+
if response.tool_calls:
|
155 |
+
tool_call = response.tool_calls[0]
|
156 |
+
function_name = tool_call["name"]
|
157 |
+
function_args = tool_call["args"]
|
158 |
+
|
159 |
+
print(f"Chatbot calling tool: {function_name} with args: {function_args}")
|
160 |
+
|
161 |
+
# Map session context to tool arguments
|
162 |
+
tool_output = {}
|
163 |
+
if function_name == "get_dataset_info":
|
164 |
+
data_json = session_context.get("processed_data", [])
|
165 |
+
if not isinstance(data_json, list) or not data_json:
|
166 |
+
print(f"Invalid or empty data_json: {data_json}")
|
167 |
+
return "Error: No valid dataset available."
|
168 |
+
tool_output = get_dataset_info.invoke({})
|
169 |
+
tool_output["summary"] = summarize_dataframe_for_chatbot(data_json)
|
170 |
+
tool_output["causal_compatibility"] = assess_causal_compatibility(data_json)
|
171 |
+
elif function_name == "get_causal_graph_info":
|
172 |
+
graph_adj = session_context.get("causal_graph_adj", [])
|
173 |
+
nodes = session_context.get("causal_graph_nodes", [])
|
174 |
+
if not graph_adj or not nodes:
|
175 |
+
print("No causal graph data available")
|
176 |
+
return "Error: No causal graph available."
|
177 |
+
tool_output = get_causal_graph_info.invoke({})
|
178 |
+
tool_output["summary"] = get_graph_summary_for_chatbot(graph_adj, nodes)
|
179 |
+
else:
|
180 |
+
print(f"Unknown tool: {function_name}")
|
181 |
+
return f"Error: Unknown tool {function_name}."
|
182 |
+
|
183 |
+
print(f"Tool output: {tool_output}")
|
184 |
+
|
185 |
+
# Create the tool output text
|
186 |
+
output_text = tool_output["summary"]
|
187 |
+
if tool_output.get("causal_compatibility"):
|
188 |
+
output_text += "\n\nCausal Compatibility Assessment:\n" + tool_output["causal_compatibility"]
|
189 |
+
|
190 |
+
# Create messages for the final response - FIXED VERSION
|
191 |
+
messages = [
|
192 |
+
HumanMessage(content=user_message),
|
193 |
+
AIMessage(content="", tool_calls=[tool_call]),
|
194 |
+
ToolMessage(content=output_text, tool_call_id=tool_call["id"])
|
195 |
+
]
|
196 |
+
|
197 |
+
# Create a follow-up prompt to ensure the model provides a comprehensive response
|
198 |
+
follow_up_prompt = ChatPromptTemplate.from_messages([
|
199 |
+
("system", """You are CausalBox Assistant. Based on the tool results, provide a comprehensive, helpful response to the user's question.
|
200 |
+
Explain the dataset characteristics, causal compatibility, and provide actionable insights for causal analysis.
|
201 |
+
Be specific about what the data shows and what causal analysis approaches would be suitable.
|
202 |
+
Always provide a complete response, not just acknowledgment."""),
|
203 |
+
("human", "{original_question}"),
|
204 |
+
("assistant", "I'll analyze the dataset information for you."),
|
205 |
+
("human", "Here's the dataset analysis: {tool_results}\n\nPlease provide a comprehensive explanation of this data and its suitability for causal analysis.")
|
206 |
+
])
|
207 |
+
|
208 |
+
# Get final response from the model with explicit prompting
|
209 |
+
print("Invoking model with tool response messages")
|
210 |
+
try:
|
211 |
+
final_chain = follow_up_prompt | model
|
212 |
+
final_response = final_chain.invoke({
|
213 |
+
"original_question": user_message,
|
214 |
+
"tool_results": output_text
|
215 |
+
})
|
216 |
+
print(f"Final response content: {final_response.content}")
|
217 |
+
|
218 |
+
if final_response.content and final_response.content.strip():
|
219 |
+
return final_response.content
|
220 |
+
else:
|
221 |
+
# Fallback response if model still returns empty
|
222 |
+
return create_fallback_response(output_text, user_message)
|
223 |
+
|
224 |
+
except Exception as e:
|
225 |
+
print(f"Error in final response generation: {e}")
|
226 |
+
return create_fallback_response(output_text, user_message)
|
227 |
+
|
228 |
+
else:
|
229 |
+
print("No tool calls, returning direct response")
|
230 |
+
if response.content and response.content.strip():
|
231 |
+
return response.content
|
232 |
+
else:
|
233 |
+
return "I'm ready to help you with causal analysis. Please ask me about your dataset, causal graphs, or any causal inference concepts you'd like to understand."
|
234 |
+
|
235 |
+
except Exception as e:
|
236 |
+
print(f"Error communicating with Groq: {e}")
|
237 |
+
return f"Sorry, I'm having trouble processing your request: {str(e)}"
|
238 |
+
|
239 |
+
def create_fallback_response(tool_output: str, user_message: str) -> str:
|
240 |
+
"""
|
241 |
+
Creates a fallback response when the model returns empty content.
|
242 |
+
"""
|
243 |
+
response_parts = ["Based on your dataset analysis:\n"]
|
244 |
+
|
245 |
+
if "Dataset Summary:" in tool_output:
|
246 |
+
response_parts.append("📊 **Dataset Overview:**")
|
247 |
+
summary_part = tool_output.split("Dataset Summary:")[1].split("Causal Compatibility Assessment:")[0]
|
248 |
+
response_parts.append(summary_part.strip())
|
249 |
+
response_parts.append("")
|
250 |
+
|
251 |
+
if "Causal Compatibility Assessment:" in tool_output:
|
252 |
+
response_parts.append("🔍 **Causal Analysis Compatibility:**")
|
253 |
+
compatibility_part = tool_output.split("Causal Compatibility Assessment:")[1]
|
254 |
+
response_parts.append(compatibility_part.strip())
|
255 |
+
response_parts.append("")
|
256 |
+
|
257 |
+
# Add specific insights based on the data
|
258 |
+
if "FinalExamScore" in tool_output:
|
259 |
+
response_parts.append("💡 **Key Insights for Causal Analysis:**")
|
260 |
+
response_parts.append("- Your dataset appears to be education-related with variables like FinalExamScore, StudyHours, and TuitionHours")
|
261 |
+
response_parts.append("- This is excellent for causal analysis as you can explore questions like:")
|
262 |
+
response_parts.append(" • Does increasing study hours causally improve exam scores?")
|
263 |
+
response_parts.append(" • What's the causal effect of tutoring (TuitionHours) on performance?")
|
264 |
+
response_parts.append(" • How does parental education influence student outcomes?")
|
265 |
+
response_parts.append("")
|
266 |
+
response_parts.append("🚀 **Next Steps:**")
|
267 |
+
response_parts.append("- Consider identifying your treatment variable (e.g., TuitionHours)")
|
268 |
+
response_parts.append("- Define your outcome variable (likely FinalExamScore)")
|
269 |
+
response_parts.append("- Identify potential confounders (ParentalEducation, SchoolType)")
|
270 |
+
|
271 |
+
return "\n".join(response_parts)
|
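Editor's note: `get_chatbot_response` expects the same session-context shape the Streamlit frontend sends to the /chatbot/message route. A minimal direct call might look like the sketch below; the sample rows and column names are illustrative only, and GROQ_API_KEY must be set for the module to initialize.

# Minimal sketch: calling the chatbot helper directly (requires GROQ_API_KEY).
from utils.causal_chatbot import get_chatbot_response

session_context = {
    "processed_data": [{"StudyHours": 5, "FinalExamScore": 72}],  # list of row dicts, as sent by the UI
    "processed_columns": ["StudyHours", "FinalExamScore"],
    "causal_graph_adj": None,
    "causal_graph_nodes": None,
}
reply = get_chatbot_response("Tell me about the dataset", session_context)
print(reply)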
utils/do_calculus.py
CHANGED
@@ -1,52 +1,52 @@
-# utils/do_calculus.py
-import pandas as pd
-import numpy as np
-import networkx as nx
-
-class DoCalculus:
-    def __init__(self, graph):
-        self.graph = graph
-
-    def intervene(self, data, intervention_var, intervention_value):
-        """
-        Simulate do(X=x) intervention on a variable.
-        Returns intervened DataFrame.
-        This is a simplified implementation.
-        """
-        intervened_data = data.copy()
-
-        # Direct intervention: set the value
-        intervened_data[intervention_var] = intervention_value
-
-        # Propagate effects (simplified linear model) - needs graph
-        # For a true do-calculus, you'd prune the graph and re-estimate based on parents
-        # For demonstration, this still uses a simplified propagation.
-        try:
-            # Ensure graph is connected and topological sort is possible
-            if self.graph and not nx.is_directed_acyclic_graph(self.graph):
-                print("Warning: Graph is not a DAG. Topological sort may fail or be incorrect for do-calculus.")
-
-            # This simplified propagation is a conceptual placeholder
-            for node in nx.topological_sort(self.graph):
-                if node == intervention_var:
-                    continue  # Do not propagate back to the intervened variable
-
-                parents = list(self.graph.predecessors(node))
-                if parents:
-                    # Very simplified linear model to show propagation
-                    # In reality, you'd use learned coefficients or structural equations
-                    combined_effect = np.zeros(len(intervened_data))
-                    for p in parents:
-                        if p in intervened_data.columns:
-                            # Use a fixed random coefficient for demonstration
-                            coeff = 0.5
-                            combined_effect += intervened_data[p].to_numpy() * coeff
-
-                    # Add a small random noise to simulate uncertainty
-                    intervened_data[node] += combined_effect + np.random.normal(0, 0.1, len(intervened_data))
-        except Exception as e:
-            print(f"Could not perform full propagation due to graph issues or simplification: {e}")
-            # Fallback to direct intervention only if graph logic fails
-            pass  # The direct intervention `intervened_data[intervention_var] = intervention_value` is already applied
-
-        return intervened_data
+# utils/do_calculus.py
+import pandas as pd
+import numpy as np
+import networkx as nx
+
+class DoCalculus:
+    def __init__(self, graph):
+        self.graph = graph
+
+    def intervene(self, data, intervention_var, intervention_value):
+        """
+        Simulate do(X=x) intervention on a variable.
+        Returns intervened DataFrame.
+        This is a simplified implementation.
+        """
+        intervened_data = data.copy()
+
+        # Direct intervention: set the value
+        intervened_data[intervention_var] = intervention_value
+
+        # Propagate effects (simplified linear model) - needs graph
+        # For a true do-calculus, you'd prune the graph and re-estimate based on parents
+        # For demonstration, this still uses a simplified propagation.
+        try:
+            # Ensure graph is connected and topological sort is possible
+            if self.graph and not nx.is_directed_acyclic_graph(self.graph):
+                print("Warning: Graph is not a DAG. Topological sort may fail or be incorrect for do-calculus.")
+
+            # This simplified propagation is a conceptual placeholder
+            for node in nx.topological_sort(self.graph):
+                if node == intervention_var:
+                    continue  # Do not propagate back to the intervened variable
+
+                parents = list(self.graph.predecessors(node))
+                if parents:
+                    # Very simplified linear model to show propagation
+                    # In reality, you'd use learned coefficients or structural equations
+                    combined_effect = np.zeros(len(intervened_data))
+                    for p in parents:
+                        if p in intervened_data.columns:
+                            # Use a fixed random coefficient for demonstration
+                            coeff = 0.5
+                            combined_effect += intervened_data[p].to_numpy() * coeff
+
+                    # Add a small random noise to simulate uncertainty
+                    intervened_data[node] += combined_effect + np.random.normal(0, 0.1, len(intervened_data))
+        except Exception as e:
+            print(f"Could not perform full propagation due to graph issues or simplification: {e}")
+            # Fallback to direct intervention only if graph logic fails
+            pass  # The direct intervention `intervened_data[intervention_var] = intervention_value` is already applied
+
+        return intervened_data
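A minimal usage sketch for the intervention helper above. The toy graph, the column names (StudyHours, FinalExamScore), and the intervened value are illustrative assumptions, not part of the repository.

# Hypothetical example: simulate do(StudyHours = 8) on a toy dataset.
import pandas as pd
import networkx as nx
from utils.do_calculus import DoCalculus

graph = nx.DiGraph([("StudyHours", "FinalExamScore")])  # assumed toy causal graph
data = pd.DataFrame({
    "StudyHours": [2, 4, 6],
    "FinalExamScore": [55.0, 65.0, 75.0],
})

do_calc = DoCalculus(graph)
intervened = do_calc.intervene(data, "StudyHours", 8)
print(intervened)  # StudyHours fixed at 8; downstream columns shifted by the simplified propagation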
utils/graph_utils.py
CHANGED
@@ -1,60 +1,107 @@
-# utils/graph_utils.py
-import networkx as nx
-import plotly.graph_objects as go
+# utils/graph_utils.py
+import networkx as nx
+import plotly.graph_objects as go
+import numpy as np
+
+def visualize_graph(graph):
+    """
+    Visualize a causal graph using Plotly.
+    Returns Plotly figure as JSON.
+    """
+    # Use a fixed seed for layout reproducibility (optional)
+    pos = nx.spring_layout(graph, seed=42)
+
+    edge_x, edge_y = [], []
+    for edge in graph.edges():
+        x0, y0 = pos[edge[0]]
+        x1, y1 = pos[edge[1]]
+        edge_x.extend([x0, x1, None])
+        edge_y.extend([y0, y1, None])
+
+    edge_trace = go.Scatter(
+        x=edge_x, y=edge_y,
+        line=dict(width=1, color='#888'),
+        mode='lines',
+        hoverinfo='none'
+    )
+
+    node_x, node_y = [], []
+    for node in graph.nodes():
+        x, y = pos[node]
+        node_x.append(x)
+        node_y.append(y)
+
+    node_trace = go.Scatter(
+        x=node_x, y=node_y,
+        mode='markers+text',
+        text=list(graph.nodes()),
+        textposition='bottom center',
+        marker=dict(size=15, color='lightblue', line=dict(width=2, color='DarkSlateGrey')),
+        hoverinfo='text'
+    )
+
+    fig = go.Figure(
+        data=[edge_trace, node_trace],
+        layout=go.Layout(
+            showlegend=False,
+            hovermode='closest',
+            margin=dict(b=20, l=5, r=5, t=40),
+            annotations=[dict(
+                text="Python Causal Graph",
+                showarrow=False,
+                xref="paper", yref="paper",
+                x=0.005, y=-0.002,
+                font=dict(size=14, color="lightgray")
+            )],
+            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+            title=dict(text="Causal Graph Visualization", font=dict(size=16))  # Corrected line
+        )
+    )
+    return fig.to_json()
+
+def get_graph_summary_for_chatbot(graph_adj, nodes):
+    """
+    Generates a text summary of the causal graph for the chatbot.
+    """
+    if not graph_adj or not nodes:
+        return "No causal graph discovered yet."
+
+    adj_matrix = np.array(graph_adj)
+    G = nx.DiGraph(adj_matrix)
+
+    # Relabel nodes with actual names
+    mapping = {i: node_name for i, node_name in enumerate(nodes)}
+    G = nx.relabel_nodes(G, mapping)
+
+    num_nodes = G.number_of_nodes()
+    num_edges = G.number_of_edges()
+
+    summary = (
+        f"The causal graph has {num_nodes} variables (nodes) and {num_edges} causal relationships (directed edges).\n"
+        "The variables are: " + ", ".join(nodes) + ".\n"
+    )
+
+    # Add some basic structural info
+    if nx.is_directed_acyclic_graph(G):
+        summary += "The graph is a Directed Acyclic Graph (DAG), which is typical for causal models.\n"
+    else:
+        summary += "The graph contains cycles, which might indicate feedback loops or issues with the discovery algorithm for a DAG model.\n"
+
+    # For small graphs, list all edges
+    if num_edges > 0 and num_edges < 10:  # Avoid listing too many edges for large graphs
+        edge_list = [f"{u} -> {v}" for u, v in G.edges()]
+        summary += "The discovered relationships are: " + ", ".join(edge_list) + ".\n"
+    elif num_edges >= 10:
+        summary += "There are many edges; you can ask for specific relationships (e.g., 'What are the direct causes of X?').\n"
+
+    # Identify source and sink nodes (if any)
+    source_nodes = [n for n, d in G.in_degree() if d == 0]
+    sink_nodes = [n for n, d in G.out_degree() if d == 0]
+
+    if source_nodes:
+        summary += f"Variables with no known causes (source nodes): {', '.join(source_nodes)}.\n"
+    if sink_nodes:
+        summary += f"Variables with no known effects (sink nodes): {', '.join(sink_nodes)}.\n"
+
+    return summary
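A short sketch of how the two helpers above might be exercised together. The graph, node names, and the use of nx.to_numpy_array to build the adjacency-matrix argument are illustrative assumptions; the helper itself only requires a nested-list adjacency matrix plus the node names.

# Hypothetical example: visualize a small graph and summarize it for the chatbot.
import networkx as nx
from utils.graph_utils import visualize_graph, get_graph_summary_for_chatbot

G = nx.DiGraph([("StudyHours", "FinalExamScore"), ("TuitionHours", "FinalExamScore")])
fig_json = visualize_graph(G)  # Plotly figure serialized to JSON, ready to send to the frontend

nodes = ["StudyHours", "TuitionHours", "FinalExamScore"]
adj = nx.to_numpy_array(G, nodelist=nodes).tolist()  # adjacency matrix in the expected nested-list form
print(get_graph_summary_for_chatbot(adj, nodes))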
utils/prediction_models.py
ADDED
@@ -0,0 +1,86 @@
+# utils/prediction_models.py
+import pandas as pd
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
+import numpy as np
+
+def train_predict_random_forest(data_list, target_col, feature_cols, prediction_type='regression'):
+    """
+    Trains a Random Forest model and performs prediction/evaluation.
+
+    Args:
+        data_list (list of dict): List of dictionaries representing the dataset.
+        target_col (str): Name of the target variable.
+        feature_cols (list): List of names of feature variables.
+        prediction_type (str): 'regression' or 'classification'.
+
+    Returns:
+        dict: A dictionary containing model results (metrics, predictions, feature importances).
+    """
+    df = pd.DataFrame(data_list)
+
+    if not all(col in df.columns for col in feature_cols + [target_col]):
+        missing_cols = [col for col in feature_cols + [target_col] if col not in df.columns]
+        raise ValueError(f"Missing columns in data: {missing_cols}")
+
+    X = df[feature_cols]
+    y = df[target_col]
+
+    # Handle categorical features if any
+    X = pd.get_dummies(X, drop_first=True)  # One-hot encode categorical features
+
+    # Split data for robust evaluation
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+    results = {}
+
+    if prediction_type == 'regression':
+        model = RandomForestRegressor(n_estimators=100, random_state=42)
+        model.fit(X_train, y_train)
+        y_pred = model.predict(X_test)
+
+        results['model_type'] = 'Regression'
+        results['r2_score'] = r2_score(y_test, y_pred)
+        results['mean_squared_error'] = mean_squared_error(y_test, y_pred)
+        results['root_mean_squared_error'] = np.sqrt(mean_squared_error(y_test, y_pred))
+        results['actual_vs_predicted'] = pd.DataFrame({'Actual': y_test, 'Predicted': y_pred}).to_dict(orient='list')
+
+    elif prediction_type == 'classification':
+        # Ensure target variable is suitable for classification (e.g., integer/categorical)
+        # Encode using the categories of the full target column so train/test codes stay consistent
+        if y.dtype == 'object' or y.dtype.name == 'category':
+            y_unique_labels = df[target_col].astype('category').cat.categories.tolist()
+            y_train = pd.Categorical(y_train, categories=y_unique_labels).codes
+            y_test = pd.Categorical(y_test, categories=y_unique_labels).codes
+            results['class_labels'] = y_unique_labels
+        else:
+            y_unique_labels = sorted(y.unique().tolist())
+            results['class_labels'] = y_unique_labels
+
+        model = RandomForestClassifier(n_estimators=100, random_state=42)
+        model.fit(X_train, y_train)
+        y_pred = model.predict(X_test)
+
+        results['model_type'] = 'Classification'
+        results['accuracy'] = accuracy_score(y_test, y_pred)
+
+        # Precision, Recall, F1-score - use 'weighted' average for multi-class
+        precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted', zero_division=0)
+        results['precision'] = precision
+        results['recall'] = recall
+        results['f1_score'] = f1
+
+        results['confusion_matrix'] = confusion_matrix(y_test, y_pred).tolist()
+        results['classification_report'] = classification_report(y_test, y_pred, output_dict=True, zero_division=0)
+
+    else:
+        raise ValueError("prediction_type must be 'regression' or 'classification'")
+
+    # Feature Importance (common for both)
+    if hasattr(model, 'feature_importances_'):
+        feature_importances = pd.Series(model.feature_importances_, index=X.columns).sort_values(ascending=False)
+        results['feature_importances'] = feature_importances.to_dict()
+
+    return results
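A minimal usage sketch for train_predict_random_forest in regression mode. The synthetic records, column names, and coefficients are made up purely for illustration.

# Hypothetical example: train a regression forest on a tiny synthetic dataset.
import numpy as np
from utils.prediction_models import train_predict_random_forest

rng = np.random.default_rng(0)
records = [
    {"StudyHours": float(h), "TuitionHours": float(t),
     "FinalExamScore": 40.0 + 5.0 * h + 3.0 * t + rng.normal(0, 2)}
    for h in range(1, 6) for t in range(0, 3)
]  # 15 assumed rows so the 80/20 split leaves a usable test set

results = train_predict_random_forest(
    data_list=records,
    target_col="FinalExamScore",
    feature_cols=["StudyHours", "TuitionHours"],
    prediction_type="regression",
)
print(results["r2_score"], results["feature_importances"])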
utils/preprocessor.py
CHANGED
@@ -1,57 +1,88 @@
-# utils/preprocessor.py
-from sklearn.preprocessing import StandardScaler, LabelEncoder
-import pandas as pd
-import numpy as np
-import logging
+# utils/preprocessor.py
+from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
+import pandas as pd
+import numpy as np
+import logging
+from sklearn.impute import SimpleImputer
+from sklearn.compose import ColumnTransformer
+from sklearn.pipeline import Pipeline
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class DataPreprocessor:
+    def __init__(self):
+        self.scaler = StandardScaler()
+        self.label_encoders = {}
+
+    def preprocess(self, df):
+        """
+        Preprocess DataFrame: handle missing values, encode categorical variables, scale numerical variables.
+        """
+        try:
+            logger.info(f"Input DataFrame shape: {df.shape}, columns: {list(df.columns)}")
+            df_processed = df.copy()
+
+            # Handle missing values
+            logger.info("Handling missing values...")
+            for col in df_processed.columns:
+                if df_processed[col].isnull().any():
+                    if pd.api.types.is_numeric_dtype(df_processed[col]):
+                        df_processed[col] = df_processed[col].fillna(df_processed[col].mean())
+                        logger.info(f"Filled numeric missing values in '{col}' with mean.")
+                    else:
+                        df_processed[col] = df_processed[col].fillna(df_processed[col].mode()[0])
+                        logger.info(f"Filled categorical missing values in '{col}' with mode.")
+
+            # Encode categorical variables
+            logger.info("Encoding categorical variables...")
+            for col in df_processed.select_dtypes(include=['object', 'category']).columns:
+                logger.info(f"Encoding column: {col}")
+                self.label_encoders[col] = LabelEncoder()
+                df_processed[col] = self.label_encoders[col].fit_transform(df_processed[col])
+
+            # Scale numerical variables
+            logger.info("Scaling numerical variables...")
+            numeric_cols = df_processed.select_dtypes(include=[np.number]).columns
+            if len(numeric_cols) > 0:
+                # Exclude columns that are now effectively categorical (post-label encoding)
+                # This is a heuristic; ideally, identify original numeric columns.
+                cols_to_scale = [col for col in numeric_cols if col not in self.label_encoders]
+                if cols_to_scale:
+                    df_processed[cols_to_scale] = self.scaler.fit_transform(df_processed[cols_to_scale])
+                    logger.info(f"Scaled numeric columns: {cols_to_scale}")
+
+            logger.info(f"Preprocessed DataFrame shape: {df_processed.shape}")
+            return df_processed
+        except Exception as e:
+            logger.exception(f"Error preprocessing data: {str(e)}")
+            raise
+
+def summarize_dataframe_for_chatbot(data_list):
+    """
+    Generates a text summary of the DataFrame for chatbot interaction.
+    """
+    if not data_list:
+        return "No data loaded."
+    df = pd.DataFrame(data_list)
+    num_rows, num_cols = df.shape
+
+    col_info = []
+    for col in df.columns:
+        dtype = df[col].dtype
+        unique_vals = df[col].nunique()
+        missing_count = df[col].isnull().sum()
+
+        info = f"- {col} (Type: {dtype}"
+        if pd.api.types.is_numeric_dtype(df[col]):
+            info += f", Min: {df[col].min():.2f}, Max: {df[col].max():.2f}"
+        else:
+            info += f", Unique: {unique_vals}"
+
+        if missing_count > 0:
+            info += f", Missing: {missing_count}"
+        info += ")"
+        col_info.append(info)
+    summary = (f"Dataset Summary:\n- Rows: {num_rows}, Columns: {num_cols}\nColumns:\n" + "\n".join(col_info))
+    return summary
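A small sketch of how the preprocessor and the chatbot summary helper above might be used. The DataFrame contents and column names are illustrative assumptions only.

# Hypothetical example: preprocess a tiny DataFrame and summarize it for the chatbot.
import pandas as pd
from utils.preprocessor import DataPreprocessor, summarize_dataframe_for_chatbot

df = pd.DataFrame({
    "StudyHours": [2.0, 4.0, None, 6.0],
    "SchoolType": ["Public", "Private", "Public", None],
    "FinalExamScore": [55.0, 65.0, 60.0, 75.0],
})

pre = DataPreprocessor()
clean = pre.preprocess(df)  # imputes missing values, label-encodes SchoolType, scales numeric columns
print(clean)

print(summarize_dataframe_for_chatbot(df.to_dict(orient="records")))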
utils/time_series_causal.py
ADDED
@@ -0,0 +1,102 @@
+# utils/time_series_causal.py
+import pandas as pd
+from statsmodels.tsa.stattools import grangercausalitytests
+
+def perform_granger_causality(data_list, timestamp_col, variables_to_analyze, max_lags=1):
+    """
+    Performs pairwise Granger Causality tests on the given time-series data.
+
+    Args:
+        data_list (list of dict): List of dictionaries representing the dataset.
+        timestamp_col (str): Name of the timestamp column.
+        variables_to_analyze (list): List of names of variables to test for causality.
+        max_lags (int): The maximum number of lags to use for the Granger causality test.
+
+    Returns:
+        list: A list of dictionaries, each describing a causal relationship found.
+    """
+    df = pd.DataFrame(data_list)
+
+    if timestamp_col not in df.columns:
+        raise ValueError(f"Timestamp column '{timestamp_col}' not found in data.")
+
+    # Ensure timestamp column is datetime and set as index
+    try:
+        df[timestamp_col] = pd.to_datetime(df[timestamp_col])
+        df = df.set_index(timestamp_col).sort_index()
+    except Exception as e:
+        raise ValueError(f"Could not convert timestamp column '{timestamp_col}' to datetime: {e}")
+
+    # Ensure all variables to analyze are numeric
+    for col in variables_to_analyze:
+        if not pd.api.types.is_numeric_dtype(df[col]):
+            raise ValueError(f"Variable '{col}' is not numeric. Granger Causality requires numeric variables.")
+        if df[col].isnull().any():
+            # Handle NaNs: Granger Causality tests require no NaN values.
+            # You might choose to drop rows with NaNs or impute.
+            # For simplicity, here we'll drop them.
+            # print(f"Warning: Variable '{col}' contains NaN values. Rows with NaNs will be dropped.")
+            df = df.dropna(subset=[col])
+
+    # Select only the relevant columns
+    df_selected = df[variables_to_analyze]
+
+    # Granger Causality requires stationarity in theory.
+    # While statsmodels can run on non-stationary data, results should be interpreted cautiously.
+    # You might want to add differencing logic here (e.g., df.diff().dropna())
+    # or a warning for the user.
+    # For now, we proceed directly.
+
+    causal_results = []
+
+    # Iterate through all unique pairs of variables
+    for i in range(len(variables_to_analyze)):
+        for j in range(len(variables_to_analyze)):
+            if i == j:
+                continue  # Skip self-causation tests
+
+            cause_var = variables_to_analyze[i]
+            effect_var = variables_to_analyze[j]
+
+            # Prepare data for grangercausalitytests: [effect_var, cause_var]
+            # grangercausalitytests takes a DataFrame where the first column is the dependent variable (effect)
+            # and the second column is the independent variable (cause)
+            data_for_test = df_selected[[effect_var, cause_var]]
+
+            if data_for_test.empty or len(data_for_test) <= max_lags:
+                # Not enough data points to perform the test with the specified lags
+                # This can happen if NaNs were dropped or the dataset is too small
+                continue
+
+            try:
+                # Perform Granger Causality test
+                # The output is a dictionary. The key 'ssr_ftest' (or 'params_ftest')
+                # usually contains the p-value.
+                test_result = grangercausalitytests(data_for_test, max_lags, verbose=False)
+
+                # Extract p-value for the optimal lag or the test that interests you
+                # Commonly, the F-test p-value for the last lag tested is used
+                # test_result is a dictionary where keys are lag numbers
+                # Each lag has a tuple of (test_statistics, p_values).
+                # (F-test, Chi2-test, LR-test, SSR-test) -> [statistic, p-value, df_denom, df_num]
+
+                # Let's consider the F-test for the last lag as a general indicator
+                last_lag_p_value = test_result[max_lags][0]['ssr_ftest'][1]  # F-test p-value
+
+                causal_results.append({
+                    "cause": cause_var,
+                    "effect": effect_var,
+                    "p_value": last_lag_p_value,
+                    "test_type": "Granger Causality (F-test)",
+                    "max_lags": max_lags
+                })
+            except ValueError as ve:
+                # Handle cases where the test cannot be performed (e.g., singular matrix)
+                print(f"Could not perform Granger Causality for {cause_var} -> {effect_var} with max_lags={max_lags}: {ve}")
+                continue  # Skip this pair
+            except Exception as e:
+                print(f"An unexpected error occurred for {cause_var} -> {effect_var}: {e}")
+                continue
+
+    return causal_results
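A minimal sketch of calling perform_granger_causality on a short synthetic series. The series construction, column names, and lag choice are assumptions made only for illustration.

# Hypothetical example: pairwise Granger tests on two synthetic daily series.
import numpy as np
import pandas as pd
from utils.time_series_causal import perform_granger_causality

rng = np.random.default_rng(0)
n = 60
x = rng.normal(size=n).cumsum()
y = np.roll(x, 1) * 0.8 + rng.normal(scale=0.5, size=n)  # y roughly follows x with a one-step lag
dates = pd.date_range("2024-01-01", periods=n, freq="D")

records = pd.DataFrame({"date": dates.astype(str), "x": x, "y": y}).to_dict(orient="records")
results = perform_granger_causality(records, timestamp_col="date",
                                    variables_to_analyze=["x", "y"], max_lags=2)
for r in results:
    print(f"{r['cause']} -> {r['effect']}: p = {r['p_value']:.4f}")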
utils/treatment_effects.py
CHANGED
@@ -1,63 +1,63 @@
-# utils/treatment_effects.py
-from sklearn.linear_model import LinearRegression, LogisticRegression
-import pandas as pd
-import numpy as np
-# For matching-based methods, you might need libraries like dowhy or causalml
-# import statsmodels.api as sm  # Example for regression diagnostics
-
-class TreatmentEffectAlgorithms:
-    def linear_regression_ate(self, df, treatment_col, outcome_col, covariates):
-        """
-        Estimate ATE using linear regression.
-        """
-        X = df[covariates + [treatment_col]]
-        y = df[outcome_col]
-        model = LinearRegression()
-        model.fit(X, y)
-        ate = model.coef_[-1]  # Coefficient of treatment_col
-        return float(ate)
-
-    def propensity_score_matching(self, df, treatment_col, outcome_col, covariates):
-        """
-        Placeholder for Propensity Score Matching.
-        You would implement or integrate a matching algorithm here.
-        """
-        print("Propensity Score Matching is a placeholder. Returning a dummy ATE.")
-        # Simplified: Estimate propensity scores
-        X_propensity = df[covariates]
-        T_propensity = df[treatment_col]
-        prop_model = LogisticRegression(solver='liblinear')
-        prop_model.fit(X_propensity, T_propensity)
-        propensity_scores = prop_model.predict_proba(X_propensity)[:, 1]
-
-        # Dummy ATE calculation for demonstration
-        treated_outcome = df[df[treatment_col] == 1][outcome_col].mean()
-        control_outcome = df[df[treatment_col] == 0][outcome_col].mean()
-        return float(treated_outcome - control_outcome)  # Simplified dummy ATE
-
-    def inverse_propensity_weighting(self, df, treatment_col, outcome_col, covariates):
-        """
-        Placeholder for Inverse Propensity Weighting (IPW).
-        You would implement or integrate IPW here.
-        """
-        print("Inverse Propensity Weighting is a placeholder. Returning a dummy ATE.")
-        # Dummy ATE for demonstration
-        return np.random.rand() * 10  # Random dummy value
-
-    def t_learner(self, df, treatment_col, outcome_col, covariates):
-        """
-        Placeholder for T-learner.
-        You would implement a T-learner using two separate models.
-        """
-        print("T-learner is a placeholder. Returning a dummy ATE.")
-        # Dummy ATE for demonstration
-        return np.random.rand() * 10 + 5  # Random dummy value
-
-    def s_learner(self, df, treatment_col, outcome_col, covariates):
-        """
-        Placeholder for S-learner.
-        You would implement an S-learner using a single model.
-        """
-        print("S-learner is a placeholder. Returning a dummy ATE.")
-        # Dummy ATE for demonstration
-        return np.random.rand() * 10 - 2  # Random dummy value
+# utils/treatment_effects.py
+from sklearn.linear_model import LinearRegression, LogisticRegression
+import pandas as pd
+import numpy as np
+# For matching-based methods, you might need libraries like dowhy or causalml
+# import statsmodels.api as sm  # Example for regression diagnostics
+
+class TreatmentEffectAlgorithms:
+    def linear_regression_ate(self, df, treatment_col, outcome_col, covariates):
+        """
+        Estimate ATE using linear regression.
+        """
+        X = df[covariates + [treatment_col]]
+        y = df[outcome_col]
+        model = LinearRegression()
+        model.fit(X, y)
+        ate = model.coef_[-1]  # Coefficient of treatment_col
+        return float(ate)
+
+    def propensity_score_matching(self, df, treatment_col, outcome_col, covariates):
+        """
+        Placeholder for Propensity Score Matching.
+        You would implement or integrate a matching algorithm here.
+        """
+        print("Propensity Score Matching is a placeholder. Returning a dummy ATE.")
+        # Simplified: Estimate propensity scores
+        X_propensity = df[covariates]
+        T_propensity = df[treatment_col]
+        prop_model = LogisticRegression(solver='liblinear')
+        prop_model.fit(X_propensity, T_propensity)
+        propensity_scores = prop_model.predict_proba(X_propensity)[:, 1]
+
+        # Dummy ATE calculation for demonstration
+        treated_outcome = df[df[treatment_col] == 1][outcome_col].mean()
+        control_outcome = df[df[treatment_col] == 0][outcome_col].mean()
+        return float(treated_outcome - control_outcome)  # Simplified dummy ATE
+
+    def inverse_propensity_weighting(self, df, treatment_col, outcome_col, covariates):
+        """
+        Placeholder for Inverse Propensity Weighting (IPW).
+        You would implement or integrate IPW here.
+        """
+        print("Inverse Propensity Weighting is a placeholder. Returning a dummy ATE.")
+        # Dummy ATE for demonstration
+        return np.random.rand() * 10  # Random dummy value
+
+    def t_learner(self, df, treatment_col, outcome_col, covariates):
+        """
+        Placeholder for T-learner.
+        You would implement a T-learner using two separate models.
+        """
+        print("T-learner is a placeholder. Returning a dummy ATE.")
+        # Dummy ATE for demonstration
+        return np.random.rand() * 10 + 5  # Random dummy value
+
+    def s_learner(self, df, treatment_col, outcome_col, covariates):
+        """
+        Placeholder for S-learner.
+        You would implement an S-learner using a single model.
+        """
+        print("S-learner is a placeholder. Returning a dummy ATE.")
+        # Dummy ATE for demonstration
+        return np.random.rand() * 10 - 2  # Random dummy value
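A minimal usage sketch for the regression-adjustment estimator above (the only non-placeholder method). The data-generating process, column names, and true effect size are illustrative assumptions.

# Hypothetical example: estimate the ATE of a binary treatment with linear regression adjustment.
import numpy as np
import pandas as pd
from utils.treatment_effects import TreatmentEffectAlgorithms

rng = np.random.default_rng(0)
n = 200
confounder = rng.normal(size=n)
treatment = (confounder + rng.normal(size=n) > 0).astype(int)   # treatment depends on the confounder
outcome = 2.0 * treatment + 1.5 * confounder + rng.normal(size=n)  # assumed true ATE of 2.0

df = pd.DataFrame({"T": treatment, "C": confounder, "Y": outcome})
algos = TreatmentEffectAlgorithms()
ate = algos.linear_regression_ate(df, treatment_col="T", outcome_col="Y", covariates=["C"])
print(f"Estimated ATE: {ate:.2f}")  # should land near 2.0 because the confounder is adjusted for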