ShutterStack committed on
Commit ab66d4e · verified · 1 Parent(s): b8611e0

major changes

Browse files
Files changed (39)
  1. .env +1 -0
  2. .gitattributes +35 -35
  3. data/sample_dataset.csv +0 -0
  4. main.py +45 -37
  5. requirements.txt +14 -11
  6. routers/__pycache__/chatbot_routes.cpython-310.pyc +0 -0
  7. routers/__pycache__/discover_routes.cpython-310.pyc +0 -0
  8. routers/__pycache__/intervene_routes.cpython-310.pyc +0 -0
  9. routers/__pycache__/prediction_routes.cpython-310.pyc +0 -0
  10. routers/__pycache__/preprocess_routes.cpython-310.pyc +0 -0
  11. routers/__pycache__/timeseries_routes.cpython-310.pyc +0 -0
  12. routers/__pycache__/treatment_routes.cpython-310.pyc +0 -0
  13. routers/__pycache__/visualize_routes.cpython-310.pyc +0 -0
  14. routers/chatbot_routes.py +25 -0
  15. routers/discover_routes.py +42 -42
  16. routers/intervene_routes.py +53 -53
  17. routers/prediction_routes.py +27 -0
  18. routers/preprocess_routes.py +55 -55
  19. routers/timeseries_routes.py +30 -0
  20. routers/treatment_routes.py +53 -53
  21. routers/visualize_routes.py +42 -42
  22. scripts/generate_data.py +29 -29
  23. streamlit_app.py +618 -307
  24. utils/__pycache__/casual_algorithms.cpython-310.pyc +0 -0
  25. utils/__pycache__/causal_chatbot.cpython-310.pyc +0 -0
  26. utils/__pycache__/do_calculus.cpython-310.pyc +0 -0
  27. utils/__pycache__/graph_utils.cpython-310.pyc +0 -0
  28. utils/__pycache__/prediction_models.cpython-310.pyc +0 -0
  29. utils/__pycache__/preprocessor.cpython-310.pyc +0 -0
  30. utils/__pycache__/time_series_causal.cpython-310.pyc +0 -0
  31. utils/__pycache__/treatment_effects.cpython-310.pyc +0 -0
  32. utils/casual_algorithms.py +63 -63
  33. utils/causal_chatbot.py +271 -0
  34. utils/do_calculus.py +51 -51
  35. utils/graph_utils.py +107 -60
  36. utils/prediction_models.py +86 -0
  37. utils/preprocessor.py +88 -57
  38. utils/time_series_causal.py +102 -0
  39. utils/treatment_effects.py +62 -62
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ GROQ_API_KEY=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
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
data/sample_dataset.csv CHANGED
The diff for this file is too large to render. See raw diff
 
main.py CHANGED
@@ -1,38 +1,46 @@
1
- # main.py
2
- from flask import Flask, jsonify, request
3
- from flask_cors import CORS
4
- import os
5
- import sys
6
-
7
- # Add the 'routers' and 'utils' directories to the Python path
8
- # This allows direct imports like 'from routers.preprocess_routes import preprocess_bp'
9
- script_dir = os.path.dirname(__file__)
10
- sys.path.insert(0, os.path.join(script_dir, 'routers'))
11
- sys.path.insert(0, os.path.join(script_dir, 'utils'))
12
-
13
- # Import Blueprints
14
- from routers.preprocess_routes import preprocess_bp
15
- from routers.discover_routes import discover_bp
16
- from routers.intervene_routes import intervene_bp
17
- from routers.treatment_routes import treatment_bp
18
- from routers.visualize_routes import visualize_bp
19
-
20
- app = Flask(__name__)
21
- CORS(app) # Enable CORS for frontend interaction
22
-
23
- # Register Blueprints
24
- app.register_blueprint(preprocess_bp, url_prefix='/preprocess')
25
- app.register_blueprint(discover_bp, url_prefix='/discover')
26
- app.register_blueprint(intervene_bp, url_prefix='/intervene')
27
- app.register_blueprint(treatment_bp, url_prefix='/treatment')
28
- app.register_blueprint(visualize_bp, url_prefix='/visualize')
29
-
30
- @app.route('/')
31
- def home():
32
- return "Welcome to CausalBox Backend API!"
33
-
34
- if __name__ == '__main__':
35
- # Ensure the 'data' directory exists for storing datasets
36
- os.makedirs('data', exist_ok=True)
37
- # Run the Flask app
 
 
 
 
 
 
 
 
38
  app.run(debug=True, host='0.0.0.0', port=5000)
 
1
+ # main.py
2
+ from flask import Flask, jsonify, request
3
+ from flask_cors import CORS
4
+ import os
5
+ import sys
6
+ from dotenv import load_dotenv
7
+ load_dotenv()
8
+
9
+ # Add the 'routers' and 'utils' directories to the Python path
10
+ # This allows direct imports like 'from routers.preprocess_routes import preprocess_bp'
11
+ script_dir = os.path.dirname(__file__)
12
+ sys.path.insert(0, os.path.join(script_dir, 'routers'))
13
+ sys.path.insert(0, os.path.join(script_dir, 'utils'))
14
+
15
+ # Import Blueprints
16
+ from routers.preprocess_routes import preprocess_bp
17
+ from routers.discover_routes import discover_bp
18
+ from routers.intervene_routes import intervene_bp
19
+ from routers.treatment_routes import treatment_bp
20
+ from routers.visualize_routes import visualize_bp
21
+ from routers.prediction_routes import prediction_bp
22
+ from routers.timeseries_routes import timeseries_bp
23
+ from routers.chatbot_routes import chatbot_bp
24
+
25
+ app = Flask(__name__)
26
+ CORS(app) # Enable CORS for frontend interaction
27
+
28
+ # Register Blueprints
29
+ app.register_blueprint(preprocess_bp, url_prefix='/preprocess')
30
+ app.register_blueprint(discover_bp, url_prefix='/discover')
31
+ app.register_blueprint(intervene_bp, url_prefix='/intervene')
32
+ app.register_blueprint(treatment_bp, url_prefix='/treatment')
33
+ app.register_blueprint(visualize_bp, url_prefix='/visualize')
34
+ app.register_blueprint(prediction_bp, url_prefix='/prediction')
35
+ app.register_blueprint(timeseries_bp, url_prefix='/timeseries')
36
+ app.register_blueprint(chatbot_bp, url_prefix='/chatbot')
37
+
38
+ @app.route('/')
39
+ def home():
40
+ return "Welcome to CausalBox Backend API!"
41
+
42
+ if __name__ == '__main__':
43
+ # Ensure the 'data' directory exists for storing datasets
44
+ os.makedirs('data', exist_ok=True)
45
+ # Run the Flask app
46
  app.run(debug=True, host='0.0.0.0', port=5000)
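Note: with `load_dotenv()` now called at startup, the variable defined in the new `.env` file becomes readable through the process environment. A minimal sketch, illustrative only and not part of this commit:

```python
# Illustrative check, assuming python-dotenv is installed and .env sits in the
# working directory (as it does in this repo).
import os
from dotenv import load_dotenv

load_dotenv()  # populates os.environ from .env
api_key = os.getenv("GROQ_API_KEY")
if api_key is None:
    raise RuntimeError("GROQ_API_KEY is not set; check the .env file.")
```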
requirements.txt CHANGED
@@ -1,11 +1,14 @@
1
- Flask
2
- flask-cors
3
- pandas
4
- numpy
5
- scikit-learn
6
- causal-learn # For PC algorithm
7
- networkx
8
- plotly
9
- streamlit
10
- requests # For Streamlit to communicate with Flask
11
- watchfiles # For auto_refresh.py (if implemented for background tasks)
 
 
 
 
1
+ Flask
2
+ flask-cors
3
+ pandas
4
+ numpy
5
+ scikit-learn
6
+ causal-learn # For PC algorithm
7
+ networkx
8
+ plotly
9
+ streamlit
10
+ requests # For Streamlit to communicate with Flask
11
+ watchfiles # For auto_refresh.py (if implemented for background tasks)
12
+ statsmodels # For statistical models and tests
13
+ google-generativeai
14
+ python-dotenv
routers/__pycache__/chatbot_routes.cpython-310.pyc ADDED
Binary file (990 Bytes).
 
routers/__pycache__/discover_routes.cpython-310.pyc CHANGED
Binary files a/routers/__pycache__/discover_routes.cpython-310.pyc and b/routers/__pycache__/discover_routes.cpython-310.pyc differ
 
routers/__pycache__/intervene_routes.cpython-310.pyc CHANGED
Binary files a/routers/__pycache__/intervene_routes.cpython-310.pyc and b/routers/__pycache__/intervene_routes.cpython-310.pyc differ
 
routers/__pycache__/prediction_routes.cpython-310.pyc ADDED
Binary file (1.13 kB).
 
routers/__pycache__/preprocess_routes.cpython-310.pyc CHANGED
Binary files a/routers/__pycache__/preprocess_routes.cpython-310.pyc and b/routers/__pycache__/preprocess_routes.cpython-310.pyc differ
 
routers/__pycache__/timeseries_routes.cpython-310.pyc ADDED
Binary file (1.29 kB).
 
routers/__pycache__/treatment_routes.cpython-310.pyc CHANGED
Binary files a/routers/__pycache__/treatment_routes.cpython-310.pyc and b/routers/__pycache__/treatment_routes.cpython-310.pyc differ
 
routers/__pycache__/visualize_routes.cpython-310.pyc CHANGED
Binary files a/routers/__pycache__/visualize_routes.cpython-310.pyc and b/routers/__pycache__/visualize_routes.cpython-310.pyc differ
 
routers/chatbot_routes.py ADDED
@@ -0,0 +1,25 @@
 
 
1
+ # routers/chatbot_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ from utils.causal_chatbot import get_chatbot_response # Import the core chatbot logic
4
+
5
+ chatbot_bp = Blueprint('chatbot_bp', __name__)
6
+
7
+ @chatbot_bp.route('/message', methods=['POST'])
8
+ def handle_chat_message():
9
+ """
10
+ API endpoint for the chatbot to receive user messages and provide responses.
11
+ """
12
+ data = request.json
13
+ user_message = data.get('user_message')
14
+ # Session context includes processed_data, causal_graph_adj, etc.
15
+ session_context = data.get('session_context', {})
16
+
17
+ if not user_message:
18
+ return jsonify({"detail": "No user message provided."}), 400
19
+
20
+ try:
21
+ response_text = get_chatbot_response(user_message, session_context)
22
+ return jsonify({"response": response_text}), 200
23
+ except Exception as e:
24
+ print(f"Error in chatbot route: {e}")
25
+ return jsonify({"detail": f"An error occurred in the chatbot: {str(e)}"}), 500
routers/discover_routes.py CHANGED
@@ -1,43 +1,43 @@
1
- # routers/discover_routes.py
2
- from flask import Blueprint, request, jsonify
3
- import pandas as pd
4
- from utils.casual_algorithms import CausalDiscoveryAlgorithms
5
- import logging
6
-
7
- discover_bp = Blueprint('discover', __name__)
8
- logger = logging.getLogger(__name__)
9
-
10
- causal_discovery_algorithms = CausalDiscoveryAlgorithms()
11
-
12
- @discover_bp.route('/', methods=['POST'])
13
- def discover_causal_graph():
14
- """
15
- Discover causal graph from input data using selected algorithm.
16
- Expects 'data' key with list of dicts (preprocessed DataFrame records) and 'algorithm' string.
17
- Returns graph as adjacency matrix.
18
- """
19
- try:
20
- payload = request.json
21
- if not payload or 'data' not in payload:
22
- return jsonify({"detail": "Invalid request payload: 'data' key missing."}), 400
23
-
24
- df = pd.DataFrame(payload["data"])
25
- algorithm = payload.get("algorithm", "pc").lower() # Default to PC
26
-
27
- logger.info(f"Received discovery request with algorithm: {algorithm}, data shape: {df.shape}")
28
-
29
- if algorithm == "pc":
30
- adj_matrix = causal_discovery_algorithms.pc_algorithm(df)
31
- elif algorithm == "ges":
32
- adj_matrix = causal_discovery_algorithms.ges_algorithm(df) # Placeholder
33
- elif algorithm == "notears":
34
- adj_matrix = causal_discovery_algorithms.notears_algorithm(df) # Placeholder
35
- else:
36
- return jsonify({"detail": f"Unsupported causal discovery algorithm: {algorithm}"}), 400
37
-
38
- logger.info(f"Causal graph discovered using {algorithm}.")
39
- return jsonify({"graph": adj_matrix.tolist()})
40
-
41
- except Exception as e:
42
- logger.exception(f"Error in causal discovery: {str(e)}")
43
  return jsonify({"detail": f"Causal discovery failed: {str(e)}"}), 500
 
1
+ # routers/discover_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ import pandas as pd
4
+ from utils.casual_algorithms import CausalDiscoveryAlgorithms
5
+ import logging
6
+
7
+ discover_bp = Blueprint('discover', __name__)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ causal_discovery_algorithms = CausalDiscoveryAlgorithms()
11
+
12
+ @discover_bp.route('/', methods=['POST'])
13
+ def discover_causal_graph():
14
+ """
15
+ Discover causal graph from input data using selected algorithm.
16
+ Expects 'data' key with list of dicts (preprocessed DataFrame records) and 'algorithm' string.
17
+ Returns graph as adjacency matrix.
18
+ """
19
+ try:
20
+ payload = request.json
21
+ if not payload or 'data' not in payload:
22
+ return jsonify({"detail": "Invalid request payload: 'data' key missing."}), 400
23
+
24
+ df = pd.DataFrame(payload["data"])
25
+ algorithm = payload.get("algorithm", "pc").lower() # Default to PC
26
+
27
+ logger.info(f"Received discovery request with algorithm: {algorithm}, data shape: {df.shape}")
28
+
29
+ if algorithm == "pc":
30
+ adj_matrix = causal_discovery_algorithms.pc_algorithm(df)
31
+ elif algorithm == "ges":
32
+ adj_matrix = causal_discovery_algorithms.ges_algorithm(df) # Placeholder
33
+ elif algorithm == "notears":
34
+ adj_matrix = causal_discovery_algorithms.notears_algorithm(df) # Placeholder
35
+ else:
36
+ return jsonify({"detail": f"Unsupported causal discovery algorithm: {algorithm}"}), 400
37
+
38
+ logger.info(f"Causal graph discovered using {algorithm}.")
39
+ return jsonify({"graph": adj_matrix.tolist()})
40
+
41
+ except Exception as e:
42
+ logger.exception(f"Error in causal discovery: {str(e)}")
43
  return jsonify({"detail": f"Causal discovery failed: {str(e)}"}), 500
routers/intervene_routes.py CHANGED
@@ -1,54 +1,54 @@
1
- # routers/intervene_routes.py
2
- from flask import Blueprint, request, jsonify
3
- import pandas as pd
4
- from utils.do_calculus import DoCalculus # Will be used for more advanced intervention
5
- import networkx as nx # Assuming graph is passed or re-discovered
6
- import logging
7
-
8
- intervene_bp = Blueprint('intervene', __name__)
9
- logger = logging.getLogger(__name__)
10
-
11
- @intervene_bp.route('/', methods=['POST'])
12
- def perform_intervention():
13
- """
14
- Perform causal intervention on data.
15
- Expects 'data' (list of dicts), 'intervention_var' (column name),
16
- 'intervention_value' (numeric), and optionally 'graph' (adjacency matrix).
17
- Returns intervened data as list of dicts.
18
- """
19
- try:
20
- payload = request.json
21
- if not payload or 'data' not in payload or 'intervention_var' not in payload or 'intervention_value' not in payload:
22
- return jsonify({"detail": "Missing required intervention parameters."}), 400
23
-
24
- df = pd.DataFrame(payload["data"])
25
- intervention_var = payload["intervention_var"]
26
- intervention_value = payload["intervention_value"]
27
- graph_adj_matrix = payload.get("graph") # Optional: pass pre-discovered graph
28
-
29
- logger.info(f"Intervention request: var={intervention_var}, value={intervention_value}, data shape: {df.shape}")
30
-
31
- if intervention_var not in df.columns:
32
- return jsonify({"detail": f"Intervention variable '{intervention_var}' not found in data"}), 400
33
-
34
- # For a more advanced do-calculus, you'd need the graph structure.
35
- # Here, a simplified direct intervention is applied first.
36
- # If graph_adj_matrix is provided, you could convert it to networkx.
37
- # For full do-calculus, the DoCalculus class would need a proper graph.
38
-
39
- df_intervened = df.copy()
40
- df_intervened[intervention_var] = intervention_value
41
-
42
- # Placeholder for propagating effects using a graph if provided
43
- # if graph_adj_matrix:
44
- # graph_nx = nx.from_numpy_array(np.array(graph_adj_matrix), create_using=nx.DiGraph)
45
- # do_calculus_engine = DoCalculus(graph_nx)
46
- # df_intervened = do_calculus_engine.intervene(df_intervened, intervention_var, intervention_value)
47
- # logger.info("Propagated effects using do-calculus (simplified).")
48
-
49
- logger.info(f"Intervened data shape: {df_intervened.shape}")
50
- return jsonify({"intervened_data": df_intervened.to_dict(orient="records")})
51
-
52
- except Exception as e:
53
- logger.exception(f"Error in intervention: {str(e)}")
54
  return jsonify({"detail": f"Intervention failed: {str(e)}"}), 500
 
1
+ # routers/intervene_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ import pandas as pd
4
+ from utils.do_calculus import DoCalculus # Will be used for more advanced intervention
5
+ import networkx as nx # Assuming graph is passed or re-discovered
6
+ import logging
7
+
8
+ intervene_bp = Blueprint('intervene', __name__)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ @intervene_bp.route('/', methods=['POST'])
12
+ def perform_intervention():
13
+ """
14
+ Perform causal intervention on data.
15
+ Expects 'data' (list of dicts), 'intervention_var' (column name),
16
+ 'intervention_value' (numeric), and optionally 'graph' (adjacency matrix).
17
+ Returns intervened data as list of dicts.
18
+ """
19
+ try:
20
+ payload = request.json
21
+ if not payload or 'data' not in payload or 'intervention_var' not in payload or 'intervention_value' not in payload:
22
+ return jsonify({"detail": "Missing required intervention parameters."}), 400
23
+
24
+ df = pd.DataFrame(payload["data"])
25
+ intervention_var = payload["intervention_var"]
26
+ intervention_value = payload["intervention_value"]
27
+ graph_adj_matrix = payload.get("graph") # Optional: pass pre-discovered graph
28
+
29
+ logger.info(f"Intervention request: var={intervention_var}, value={intervention_value}, data shape: {df.shape}")
30
+
31
+ if intervention_var not in df.columns:
32
+ return jsonify({"detail": f"Intervention variable '{intervention_var}' not found in data"}), 400
33
+
34
+ # For a more advanced do-calculus, you'd need the graph structure.
35
+ # Here, a simplified direct intervention is applied first.
36
+ # If graph_adj_matrix is provided, you could convert it to networkx.
37
+ # For full do-calculus, the DoCalculus class would need a proper graph.
38
+
39
+ df_intervened = df.copy()
40
+ df_intervened[intervention_var] = intervention_value
41
+
42
+ # Placeholder for propagating effects using a graph if provided
43
+ # if graph_adj_matrix:
44
+ # graph_nx = nx.from_numpy_array(np.array(graph_adj_matrix), create_using=nx.DiGraph)
45
+ # do_calculus_engine = DoCalculus(graph_nx)
46
+ # df_intervened = do_calculus_engine.intervene(df_intervened, intervention_var, intervention_value)
47
+ # logger.info("Propagated effects using do-calculus (simplified).")
48
+
49
+ logger.info(f"Intervened data shape: {df_intervened.shape}")
50
+ return jsonify({"intervened_data": df_intervened.to_dict(orient="records")})
51
+
52
+ except Exception as e:
53
+ logger.exception(f"Error in intervention: {str(e)}")
54
  return jsonify({"detail": f"Intervention failed: {str(e)}"}), 500
routers/prediction_routes.py ADDED
@@ -0,0 +1,27 @@
 
 
1
+ # routers/prediction_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ import pandas as pd
4
+ from utils.prediction_models import train_predict_random_forest
5
+
6
+ prediction_bp = Blueprint('prediction_bp', __name__)
7
+
8
+ @prediction_bp.route('/train_predict', methods=['POST'])
9
+ def train_predict():
10
+ """
11
+ API endpoint to train a Random Forest model and perform prediction/evaluation.
12
+ """
13
+ data = request.json.get('data')
14
+ target_col = request.json.get('target_col')
15
+ feature_cols = request.json.get('feature_cols')
16
+ prediction_type = request.json.get('prediction_type')
17
+
18
+ if not all([data, target_col, feature_cols, prediction_type]):
19
+ return jsonify({"detail": "Missing required parameters for prediction."}), 400
20
+
21
+ try:
22
+ results = train_predict_random_forest(data, target_col, feature_cols, prediction_type)
23
+ return jsonify({"results": results}), 200
24
+ except ValueError as e:
25
+ return jsonify({"detail": str(e)}), 400
26
+ except Exception as e:
27
+ return jsonify({"detail": f"An error occurred during prediction: {str(e)}"}), 500
routers/preprocess_routes.py CHANGED
@@ -1,56 +1,56 @@
1
- # routers/preprocess_routes.py
2
- from flask import Blueprint, request, jsonify
3
- import pandas as pd
4
- from utils.preprocessor import DataPreprocessor
5
- import logging
6
-
7
- preprocess_bp = Blueprint('preprocess', __name__)
8
-
9
- # Set up logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
- preprocessor = DataPreprocessor()
14
-
15
- @preprocess_bp.route('/upload', methods=['POST'])
16
- def upload_file():
17
- """
18
- Upload and preprocess a CSV file.
19
- Returns preprocessed DataFrame columns and data as JSON.
20
- Optional limit_rows to reduce response size for testing.
21
- """
22
- if 'file' not in request.files:
23
- return jsonify({"detail": "No file part in the request"}), 400
24
- file = request.files['file']
25
- if file.filename == '':
26
- return jsonify({"detail": "No selected file"}), 400
27
- if not file.filename.lower().endswith('.csv'):
28
- return jsonify({"detail": "Only CSV files are supported"}), 400
29
-
30
- limit_rows = request.args.get('limit_rows', type=int)
31
-
32
- try:
33
- logger.info(f"Received file: {file.filename}")
34
- df = pd.read_csv(file)
35
- logger.info(f"CSV read successfully, shape: {df.shape}")
36
-
37
- processed_df = preprocessor.preprocess(df)
38
- if limit_rows:
39
- processed_df = processed_df.head(limit_rows)
40
- logger.info(f"Limited to {limit_rows} rows.")
41
-
42
- response = {
43
- "columns": list(processed_df.columns),
44
- "data": processed_df.to_dict(orient="records")
45
- }
46
- logger.info(f"Preprocessed {len(response['data'])} records.")
47
- return jsonify(response)
48
- except pd.errors.EmptyDataError:
49
- logger.error("Empty CSV file uploaded.")
50
- return jsonify({"detail": "Empty CSV file"}), 400
51
- except pd.errors.ParserError:
52
- logger.error("Invalid CSV format.")
53
- return jsonify({"detail": "Invalid CSV format"}), 400
54
- except Exception as e:
55
- logger.exception(f"Unexpected error during file processing: {str(e)}")
56
  return jsonify({"detail": f"Failed to process file: {str(e)}"}), 500
 
1
+ # routers/preprocess_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ import pandas as pd
4
+ from utils.preprocessor import DataPreprocessor
5
+ import logging
6
+
7
+ preprocess_bp = Blueprint('preprocess', __name__)
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ preprocessor = DataPreprocessor()
14
+
15
+ @preprocess_bp.route('/upload', methods=['POST'])
16
+ def upload_file():
17
+ """
18
+ Upload and preprocess a CSV file.
19
+ Returns preprocessed DataFrame columns and data as JSON.
20
+ Optional limit_rows to reduce response size for testing.
21
+ """
22
+ if 'file' not in request.files:
23
+ return jsonify({"detail": "No file part in the request"}), 400
24
+ file = request.files['file']
25
+ if file.filename == '':
26
+ return jsonify({"detail": "No selected file"}), 400
27
+ if not file.filename.lower().endswith('.csv'):
28
+ return jsonify({"detail": "Only CSV files are supported"}), 400
29
+
30
+ limit_rows = request.args.get('limit_rows', type=int)
31
+
32
+ try:
33
+ logger.info(f"Received file: {file.filename}")
34
+ df = pd.read_csv(file)
35
+ logger.info(f"CSV read successfully, shape: {df.shape}")
36
+
37
+ processed_df = preprocessor.preprocess(df)
38
+ if limit_rows:
39
+ processed_df = processed_df.head(limit_rows)
40
+ logger.info(f"Limited to {limit_rows} rows.")
41
+
42
+ response = {
43
+ "columns": list(processed_df.columns),
44
+ "data": processed_df.to_dict(orient="records")
45
+ }
46
+ logger.info(f"Preprocessed {len(response['data'])} records.")
47
+ return jsonify(response)
48
+ except pd.errors.EmptyDataError:
49
+ logger.error("Empty CSV file uploaded.")
50
+ return jsonify({"detail": "Empty CSV file"}), 400
51
+ except pd.errors.ParserError:
52
+ logger.error("Invalid CSV format.")
53
+ return jsonify({"detail": "Invalid CSV format"}), 400
54
+ except Exception as e:
55
+ logger.exception(f"Unexpected error during file processing: {str(e)}")
56
  return jsonify({"detail": f"Failed to process file: {str(e)}"}), 500
routers/timeseries_routes.py ADDED
@@ -0,0 +1,30 @@
 
 
1
+ # routers/timeseries_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ import pandas as pd
4
+ from utils.time_series_causal import perform_granger_causality
5
+
6
+ timeseries_bp = Blueprint('timeseries_bp', __name__)
7
+
8
+ @timeseries_bp.route('/discover_causality', methods=['POST'])
9
+ def discover_timeseries_causality():
10
+ """
11
+ API endpoint to perform time-series causal discovery (Granger Causality).
12
+ """
13
+ data = request.json.get('data')
14
+ timestamp_col = request.json.get('timestamp_col')
15
+ variables_to_analyze = request.json.get('variables_to_analyze')
16
+ max_lags = request.json.get('max_lags', 1) # Default to 1 lag
17
+
18
+ if not all([data, timestamp_col, variables_to_analyze]):
19
+ return jsonify({"detail": "Missing required parameters for time-series causal discovery."}), 400
20
+
21
+ if not isinstance(max_lags, int) or max_lags <= 0:
22
+ return jsonify({"detail": "max_lags must be a positive integer."}), 400
23
+
24
+ try:
25
+ results = perform_granger_causality(data, timestamp_col, variables_to_analyze, max_lags)
26
+ return jsonify({"results": results}), 200
27
+ except ValueError as e:
28
+ return jsonify({"detail": str(e)}), 400
29
+ except Exception as e:
30
+ return jsonify({"detail": f"An error occurred during time-series causal discovery: {str(e)}"}), 500
routers/treatment_routes.py CHANGED
@@ -1,54 +1,54 @@
1
- # routers/treatment_routes.py
2
- from flask import Blueprint, request, jsonify
3
- import pandas as pd
4
- from utils.treatment_effects import TreatmentEffectAlgorithms
5
- import logging
6
-
7
- treatment_bp = Blueprint('treatment', __name__)
8
- logger = logging.getLogger(__name__)
9
-
10
- treatment_effect_algorithms = TreatmentEffectAlgorithms()
11
-
12
- @treatment_bp.route('/estimate_ate', methods=['POST'])
13
- def estimate_ate():
14
- """
15
- Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).
16
- Expects 'data' (list of dicts), 'treatment_col', 'outcome_col', 'covariates' (list of column names),
17
- and 'method' (string for estimation method).
18
- Returns ATE/CATE as float or dictionary.
19
- """
20
- try:
21
- payload = request.json
22
- if not payload or 'data' not in payload or 'treatment_col' not in payload or 'outcome_col' not in payload or 'covariates' not in payload:
23
- return jsonify({"detail": "Missing required ATE estimation parameters."}), 400
24
-
25
- df = pd.DataFrame(payload["data"])
26
- treatment_col = payload["treatment_col"]
27
- outcome_col = payload["outcome_col"]
28
- covariates = payload["covariates"]
29
- method = payload.get("method", "linear_regression").lower() # Default to linear regression
30
-
31
- logger.info(f"ATE/CATE request: treatment={treatment_col}, outcome={outcome_col}, method={method}, data shape: {df.shape}")
32
-
33
- if not all(col in df.columns for col in [treatment_col, outcome_col] + covariates):
34
- return jsonify({"detail": "Invalid column names provided for ATE estimation."}), 400
35
-
36
- if method == "linear_regression":
37
- result = treatment_effect_algorithms.linear_regression_ate(df, treatment_col, outcome_col, covariates)
38
- elif method == "propensity_score_matching":
39
- result = treatment_effect_algorithms.propensity_score_matching(df, treatment_col, outcome_col, covariates) # Placeholder
40
- elif method == "inverse_propensity_weighting":
41
- result = treatment_effect_algorithms.inverse_propensity_weighting(df, treatment_col, outcome_col, covariates) # Placeholder
42
- elif method == "t_learner":
43
- result = treatment_effect_algorithms.t_learner(df, treatment_col, outcome_col, covariates) # Placeholder
44
- elif method == "s_learner":
45
- result = treatment_effect_algorithms.s_learner(df, treatment_col, outcome_col, covariates) # Placeholder
46
- else:
47
- return jsonify({"detail": f"Unsupported treatment effect estimation method: {method}"}), 400
48
-
49
- logger.info(f"Estimated ATE/CATE using {method}: {result}")
50
- return jsonify({"result": result})
51
-
52
- except Exception as e:
53
- logger.exception(f"Error in ATE/CATE estimation: {str(e)}")
54
  return jsonify({"detail": f"ATE/CATE estimation failed: {str(e)}"}), 500
 
1
+ # routers/treatment_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ import pandas as pd
4
+ from utils.treatment_effects import TreatmentEffectAlgorithms
5
+ import logging
6
+
7
+ treatment_bp = Blueprint('treatment', __name__)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ treatment_effect_algorithms = TreatmentEffectAlgorithms()
11
+
12
+ @treatment_bp.route('/estimate_ate', methods=['POST'])
13
+ def estimate_ate():
14
+ """
15
+ Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).
16
+ Expects 'data' (list of dicts), 'treatment_col', 'outcome_col', 'covariates' (list of column names),
17
+ and 'method' (string for estimation method).
18
+ Returns ATE/CATE as float or dictionary.
19
+ """
20
+ try:
21
+ payload = request.json
22
+ if not payload or 'data' not in payload or 'treatment_col' not in payload or 'outcome_col' not in payload or 'covariates' not in payload:
23
+ return jsonify({"detail": "Missing required ATE estimation parameters."}), 400
24
+
25
+ df = pd.DataFrame(payload["data"])
26
+ treatment_col = payload["treatment_col"]
27
+ outcome_col = payload["outcome_col"]
28
+ covariates = payload["covariates"]
29
+ method = payload.get("method", "linear_regression").lower() # Default to linear regression
30
+
31
+ logger.info(f"ATE/CATE request: treatment={treatment_col}, outcome={outcome_col}, method={method}, data shape: {df.shape}")
32
+
33
+ if not all(col in df.columns for col in [treatment_col, outcome_col] + covariates):
34
+ return jsonify({"detail": "Invalid column names provided for ATE estimation."}), 400
35
+
36
+ if method == "linear_regression":
37
+ result = treatment_effect_algorithms.linear_regression_ate(df, treatment_col, outcome_col, covariates)
38
+ elif method == "propensity_score_matching":
39
+ result = treatment_effect_algorithms.propensity_score_matching(df, treatment_col, outcome_col, covariates) # Placeholder
40
+ elif method == "inverse_propensity_weighting":
41
+ result = treatment_effect_algorithms.inverse_propensity_weighting(df, treatment_col, outcome_col, covariates) # Placeholder
42
+ elif method == "t_learner":
43
+ result = treatment_effect_algorithms.t_learner(df, treatment_col, outcome_col, covariates) # Placeholder
44
+ elif method == "s_learner":
45
+ result = treatment_effect_algorithms.s_learner(df, treatment_col, outcome_col, covariates) # Placeholder
46
+ else:
47
+ return jsonify({"detail": f"Unsupported treatment effect estimation method: {method}"}), 400
48
+
49
+ logger.info(f"Estimated ATE/CATE using {method}: {result}")
50
+ return jsonify({"result": result})
51
+
52
+ except Exception as e:
53
+ logger.exception(f"Error in ATE/CATE estimation: {str(e)}")
54
  return jsonify({"detail": f"ATE/CATE estimation failed: {str(e)}"}), 500
routers/visualize_routes.py CHANGED
@@ -1,43 +1,43 @@
1
- # routers/visualize_routes.py
2
- from flask import Blueprint, request, jsonify
3
- import pandas as pd
4
- from utils.graph_utils import visualize_graph
5
- import networkx as nx
6
- import numpy as np
7
- import logging
8
-
9
- visualize_bp = Blueprint('visualize', __name__)
10
- logger = logging.getLogger(__name__)
11
-
12
- @visualize_bp.route('/graph', methods=['POST'])
13
- def get_graph_visualization():
14
- """
15
- Generate a causal graph visualization from an adjacency matrix.
16
- Expects 'graph' (adjacency matrix as list of lists) and 'nodes' (list of node names).
17
- Returns Plotly JSON for the graph.
18
- """
19
- try:
20
- payload = request.json
21
- if not payload or 'graph' not in payload or 'nodes' not in payload:
22
- return jsonify({"detail": "Missing 'graph' or 'nodes' in request payload."}), 400
23
-
24
- adj_matrix = np.array(payload["graph"])
25
- nodes = payload["nodes"]
26
-
27
- logger.info(f"Received graph visualization request for {len(nodes)} nodes.")
28
-
29
- # Reconstruct networkx graph from adjacency matrix and node names
30
- graph_nx = nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph)
31
-
32
- # Map integer node labels back to original column names if necessary
33
- # Assuming nodes are ordered as they appear in the original dataframe or provided in 'nodes'
34
- mapping = {i: node_name for i, node_name in enumerate(nodes)}
35
- graph_nx = nx.relabel_nodes(graph_nx, mapping)
36
-
37
- graph_json = visualize_graph(graph_nx)
38
- logger.info("Generated graph visualization JSON.")
39
- return jsonify({"graph": graph_json})
40
-
41
- except Exception as e:
42
- logger.exception(f"Error generating graph visualization: {str(e)}")
43
  return jsonify({"detail": f"Failed to generate visualization: {str(e)}"}), 500
 
1
+ # routers/visualize_routes.py
2
+ from flask import Blueprint, request, jsonify
3
+ import pandas as pd
4
+ from utils.graph_utils import visualize_graph
5
+ import networkx as nx
6
+ import numpy as np
7
+ import logging
8
+
9
+ visualize_bp = Blueprint('visualize', __name__)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ @visualize_bp.route('/graph', methods=['POST'])
13
+ def get_graph_visualization():
14
+ """
15
+ Generate a causal graph visualization from an adjacency matrix.
16
+ Expects 'graph' (adjacency matrix as list of lists) and 'nodes' (list of node names).
17
+ Returns Plotly JSON for the graph.
18
+ """
19
+ try:
20
+ payload = request.json
21
+ if not payload or 'graph' not in payload or 'nodes' not in payload:
22
+ return jsonify({"detail": "Missing 'graph' or 'nodes' in request payload."}), 400
23
+
24
+ adj_matrix = np.array(payload["graph"])
25
+ nodes = payload["nodes"]
26
+
27
+ logger.info(f"Received graph visualization request for {len(nodes)} nodes.")
28
+
29
+ # Reconstruct networkx graph from adjacency matrix and node names
30
+ graph_nx = nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph)
31
+
32
+ # Map integer node labels back to original column names if necessary
33
+ # Assuming nodes are ordered as they appear in the original dataframe or provided in 'nodes'
34
+ mapping = {i: node_name for i, node_name in enumerate(nodes)}
35
+ graph_nx = nx.relabel_nodes(graph_nx, mapping)
36
+
37
+ graph_json = visualize_graph(graph_nx)
38
+ logger.info("Generated graph visualization JSON.")
39
+ return jsonify({"graph": graph_json})
40
+
41
+ except Exception as e:
42
+ logger.exception(f"Error generating graph visualization: {str(e)}")
43
  return jsonify({"detail": f"Failed to generate visualization: {str(e)}"}), 500
scripts/generate_data.py CHANGED
@@ -1,29 +1,29 @@
1
- # scripts/generate_data.py
2
- import numpy as np
3
- import pandas as pd
4
- import os
5
-
6
- def generate_dataset(n_samples=1000):
7
- np.random.seed(42)
8
- study_hours = np.random.normal(10, 2, n_samples)
9
- tuition_hours = np.random.normal(5, 1, n_samples)
10
- parental_education = np.random.choice(['High', 'Medium', 'Low'], n_samples)
11
- school_type = np.random.choice(['Public', 'Private'], n_samples)
12
- exam_score = 50 + 2 * study_hours + 1.5 * tuition_hours + np.random.normal(0, 5, n_samples)
13
-
14
- df = pd.DataFrame({
15
- 'StudyHours': study_hours,
16
- 'TuitionHours': tuition_hours,
17
- 'ParentalEducation': parental_education,
18
- 'SchoolType': school_type,
19
- 'FinalExamScore': exam_score
20
- })
21
-
22
- # Ensure data directory exists
23
- os.makedirs('../data', exist_ok=True)
24
- df.to_csv('../data/sample_dataset.csv', index=False)
25
- return df
26
-
27
- if __name__ == "__main__":
28
- generate_dataset()
29
- print("Dataset generated and saved to ../data/sample_dataset.csv")
 
1
+ # scripts/generate_data.py
2
+ import numpy as np
3
+ import pandas as pd
4
+ import os
5
+
6
+ def generate_dataset(n_samples=1000):
7
+ np.random.seed(42)
8
+ study_hours = np.random.normal(10, 2, n_samples)
9
+ tuition_hours = np.random.normal(5, 1, n_samples)
10
+ parental_education = np.random.choice(['High', 'Medium', 'Low'], n_samples)
11
+ school_type = np.random.choice(['Public', 'Private'], n_samples)
12
+ exam_score = 50 + 2 * study_hours + 1.5 * tuition_hours + np.random.normal(0, 5, n_samples)
13
+
14
+ df = pd.DataFrame({
15
+ 'StudyHours': study_hours,
16
+ 'TuitionHours': tuition_hours,
17
+ 'ParentalEducation': parental_education,
18
+ 'SchoolType': school_type,
19
+ 'FinalExamScore': exam_score
20
+ })
21
+
22
+ # Ensure data directory exists
23
+ os.makedirs('../data', exist_ok=True)
24
+ df.to_csv('../data/sample_dataset.csv', index=False)
25
+ return df
26
+
27
+ if __name__ == "__main__":
28
+ generate_dataset()
29
+ print("Dataset generated and saved to ../data/sample_dataset.csv")
streamlit_app.py CHANGED
@@ -1,308 +1,619 @@
1
- # streamlit_app.py
2
- import streamlit as st
3
- import pandas as pd
4
- import requests
5
- import json
6
- import plotly.express as px
7
- import plotly.graph_objects as go
8
- import numpy as np # For random array in placeholders
9
- import os
10
-
11
- # Configuration
12
- FLASK_API_URL = "http://localhost:5000" # Ensure this matches your Flask app's host and port
13
-
14
- st.set_page_config(layout="wide", page_title="CausalBox Toolkit")
15
-
16
- st.title("🔬 CausalBox: A Causal Inference Toolkit")
17
- st.markdown("Uncover causal relationships, simulate interventions, and estimate treatment effects.")
18
-
19
- # --- Session State Initialization ---
20
- if 'processed_data' not in st.session_state:
21
- st.session_state.processed_data = None
22
- if 'processed_columns' not in st.session_state:
23
- st.session_state.processed_columns = None
24
- if 'causal_graph_adj' not in st.session_state:
25
- st.session_state.causal_graph_adj = None
26
- if 'causal_graph_nodes' not in st.session_state:
27
- st.session_state.causal_graph_nodes = None
28
-
29
- # --- Data Preprocessing Module ---
30
- st.header("1. Data Preprocessor 🧹")
31
- st.write("Upload your CSV dataset or use a generated sample dataset.")
32
-
33
- # Option to use generated sample dataset
34
- if st.button("Use Sample Dataset (sample_dataset.csv)"):
35
- # In a real scenario, Streamlit would serve the file or you'd load it directly if local.
36
- # For this setup, we assume the Flask backend can access it or you manually upload it once.
37
- # For demonstration, we'll simulate loading a generic DataFrame.
38
- # In a full deployment, you'd have a mechanism to either:
39
- # a) Have Flask serve the sample file, or
40
- # b) Directly load it in Streamlit if the app and data are co-located.
41
- try:
42
- # Assuming the sample dataset is accessible or you are testing locally with `scripts/generate_data.py`
43
- # and then manually uploading this generated file.
44
- # For simplicity, we'll create a dummy df here if not actually uploaded.
45
- sample_df_path = "data/sample_dataset.csv" # Path relative to main.py or Streamlit app execution
46
- if os.path.exists(sample_df_path):
47
- sample_df = pd.read_csv(sample_df_path)
48
- st.success(f"Loaded sample dataset from {sample_df_path}. Please upload this file if running from different directory.")
49
- else:
50
- st.warning("Sample dataset not found at data/sample_dataset.csv.")
51
- # Dummy DataFrame for demonstration if sample file isn't found
52
- sample_df = pd.DataFrame(np.random.rand(10, 5), columns=[f'col_{i}' for i in range(5)])
53
-
54
- # Convert to JSON for Flask API call
55
- files = {'file': ('sample_dataset.csv', sample_df.to_csv(index=False), 'text/csv')}
56
- response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
57
-
58
- if response.status_code == 200:
59
- result = response.json()
60
- st.session_state.processed_data = result['data']
61
- st.session_state.processed_columns = result['columns']
62
- st.success("Sample dataset preprocessed successfully!")
63
- st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
64
- else:
65
- st.error(f"Error preprocessing sample dataset: {response.json().get('detail', 'Unknown error')}")
66
- except Exception as e:
67
- st.error(f"Could not load or process sample dataset: {e}")
68
-
69
-
70
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
71
- if uploaded_file is not None:
72
- st.info("Uploading and preprocessing data...")
73
- files = {'file': (uploaded_file.name, uploaded_file.getvalue(), 'text/csv')}
74
- try:
75
- response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
76
- if response.status_code == 200:
77
- result = response.json()
78
- st.session_state.processed_data = result['data']
79
- st.session_state.processed_columns = result['columns']
80
- st.success("File preprocessed successfully!")
81
- st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
82
- else:
83
- st.error(f"Error during preprocessing: {response.json().get('detail', 'Unknown error')}")
84
- except requests.exceptions.ConnectionError:
85
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
86
- except Exception as e:
87
- st.error(f"An unexpected error occurred: {e}")
88
-
89
- # --- Causal Discovery Module ---
90
- st.header("2. Causal Discovery 🕵️‍♂️")
91
- if st.session_state.processed_data:
92
- st.write("Learn the causal structure from your preprocessed data.")
93
-
94
- discovery_algo = st.selectbox(
95
- "Select Causal Discovery Algorithm:",
96
- ("PC Algorithm", "GES (Greedy Equivalence Search) - Placeholder", "NOTEARS - Placeholder")
97
- )
98
-
99
- if st.button("Discover Causal Graph"):
100
- st.info(f"Discovering graph using {discovery_algo}...")
101
- algo_map = {
102
- "PC Algorithm": "pc",
103
- "GES (Greedy Equivalence Search) - Placeholder": "ges",
104
- "NOTEARS - Placeholder": "notears"
105
- }
106
- selected_algo_code = algo_map[discovery_algo]
107
-
108
- try:
109
- response = requests.post(
110
- f"{FLASK_API_URL}/discover/",
111
- json={"data": st.session_state.processed_data, "algorithm": selected_algo_code}
112
- )
113
- if response.status_code == 200:
114
- result = response.json()
115
- st.session_state.causal_graph_adj = result['graph']
116
- st.session_state.causal_graph_nodes = st.session_state.processed_columns
117
- st.success("Causal graph discovered!")
118
- st.subheader("Causal Graph Visualization")
119
- # Visualization will be handled by the Causal Graph Visualizer section
120
- else:
121
- st.error(f"Error during causal discovery: {response.json().get('detail', 'Unknown error')}")
122
- except requests.exceptions.ConnectionError:
123
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
124
- except Exception as e:
125
- st.error(f"An unexpected error occurred: {e}")
126
- else:
127
- st.info("Please preprocess data first to enable causal discovery.")
128
-
129
- # --- Causal Graph Visualizer Module ---
130
- st.header("3. Causal Graph Visualizer 📊")
131
- if st.session_state.causal_graph_adj and st.session_state.causal_graph_nodes:
132
- st.write("Interactive visualization of the discovered causal graph.")
133
- try:
134
- response = requests.post(
135
- f"{FLASK_API_URL}/visualize/graph",
136
- json={"graph": st.session_state.causal_graph_adj, "nodes": st.session_state.causal_graph_nodes}
137
- )
138
- if response.status_code == 200:
139
- graph_json = response.json()['graph']
140
- fig = go.Figure(json.loads(graph_json))
141
- st.plotly_chart(fig, use_container_width=True)
142
- st.markdown("""
143
- **Graph Explanation:**
144
- * **Nodes:** Represent variables in your dataset.
145
- * **Arrows (Edges):** Indicate a direct causal influence from one variable (the tail) to another (the head).
146
- * **No Arrow:** Suggests no direct causal relationship was found, or the relationship is mediated by other variables.
147
-
148
- This graph helps answer "Why did it happen?" by showing the structural relationships.
149
- """)
150
- else:
151
- st.error(f"Error visualizing graph: {response.json().get('detail', 'Unknown error')}")
152
- except requests.exceptions.ConnectionError:
153
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
154
- except Exception as e:
155
- st.error(f"An unexpected error occurred during visualization: {e}")
156
- else:
157
- st.info("Please discover a causal graph first to visualize it.")
158
-
159
-
160
- # --- Do-Calculus Engine Module ---
161
- st.header("4. Do-Calculus Engine 🧪")
162
- if st.session_state.processed_data and st.session_state.causal_graph_adj:
163
- st.write("Simulate interventions and observe their effects based on the causal graph.")
164
-
165
- intervention_var = st.selectbox(
166
- "Select variable to intervene on:",
167
- st.session_state.processed_columns,
168
- key="inter_var_select"
169
- )
170
- # Attempt to infer type for intervention_value input
171
- # Simplified approach: assuming numerical for now due to preprocessor output
172
- if intervention_var and isinstance(st.session_state.processed_data[0][intervention_var], (int, float)):
173
- intervention_value = st.number_input(f"Set '{intervention_var}' to value:", key="inter_val_input")
174
- else: # Treat as string/categorical for input, then try to preprocess for API
175
- intervention_value = st.text_input(f"Set '{intervention_var}' to value:", key="inter_val_input_text")
176
- st.warning("Categorical intervention values might require specific encoding logic on the backend.")
177
-
178
- if st.button("Perform Intervention"):
179
- st.info(f"Performing intervention: do('{intervention_var}' = {intervention_value})...")
180
- try:
181
- response = requests.post(
182
- f"{FLASK_API_URL}/intervene/",
183
- json={
184
- "data": st.session_state.processed_data,
185
- "intervention_var": intervention_var,
186
- "intervention_value": intervention_value,
187
- "graph": st.session_state.causal_graph_adj # Pass graph for advanced do-calculus
188
- }
189
- )
190
- if response.status_code == 200:
191
- intervened_data = pd.DataFrame(response.json()['intervened_data'])
192
- st.success("Intervention simulated successfully!")
193
- st.subheader("Intervened Data (First 10 rows)")
194
- st.dataframe(intervened_data.head(10))
195
-
196
- # Simple comparison visualization (e.g., histogram of outcome variable)
197
- if st.session_state.processed_columns and 'FinalExamScore' in st.session_state.processed_columns:
198
- original_df = pd.DataFrame(st.session_state.processed_data)
199
- fig_dist = go.Figure()
200
- fig_dist.add_trace(go.Histogram(x=original_df['FinalExamScore'], name='Original', opacity=0.7))
201
- fig_dist.add_trace(go.Histogram(x=intervened_data['FinalExamScore'], name='Intervened', opacity=0.0))
202
-
203
- st.plotly_chart(fig_dist, use_container_width=True)
204
- st.markdown("""
205
- **Intervention Explanation:**
206
- * By simulating `do(X=x)`, we are forcing the value of X, effectively breaking its causal links from its parents.
207
- * The graph above shows the distribution of a key outcome variable (e.g., `FinalExamScore`) before and after the intervention.
208
- * This helps answer "What if we do this instead?" by showing the predicted outcome.
209
- """)
210
- else:
211
- st.info("Consider adding a relevant outcome variable to your dataset for better intervention analysis.")
212
- else:
213
- st.error(f"Error during intervention: {response.json().get('detail', 'Unknown error')}")
214
- except requests.exceptions.ConnectionError:
215
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
216
- except Exception as e:
217
- st.error(f"An unexpected error occurred during intervention: {e}")
218
- else:
219
- st.info("Please preprocess data and discover a causal graph first to perform interventions.")
220
-
221
- # --- Treatment Effect Estimator Module ---
222
- st.header("5. Treatment Effect Estimator 🎯")
223
- if st.session_state.processed_data:
224
- st.write("Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).")
225
-
226
- col1, col2 = st.columns(2)
227
- with col1:
228
- treatment_col = st.selectbox(
229
- "Select Treatment Variable:",
230
- st.session_state.processed_columns,
231
- key="treat_col_select"
232
- )
233
- with col2:
234
- outcome_col = st.selectbox(
235
- "Select Outcome Variable:",
236
- st.session_state.processed_columns,
237
- key="outcome_col_select"
238
- )
239
-
240
- all_cols_except_treat_outcome = [col for col in st.session_state.processed_columns if col not in [treatment_col, outcome_col]]
241
- covariates = st.multiselect(
242
- "Select Covariates (confounders):",
243
- all_cols_except_treat_outcome,
244
- default=all_cols_except_treat_outcome, # Default to all other columns
245
- key="covariates_select"
246
- )
247
-
248
- estimation_method = st.selectbox(
249
- "Select Estimation Method:",
250
- (
251
- "Linear Regression ATE",
252
- "Propensity Score Matching - Placeholder",
253
- "Inverse Propensity Weighting - Placeholder",
254
- "T-learner - Placeholder",
255
- "S-learner - Placeholder"
256
- )
257
- )
258
-
259
- if st.button("Estimate Treatment Effect"):
260
- st.info(f"Estimating treatment effect using {estimation_method}...")
261
- method_map = {
262
- "Linear Regression ATE": "linear_regression",
263
- "Propensity Score Matching - Placeholder": "propensity_score_matching",
264
- "Inverse Propensity Weighting - Placeholder": "inverse_propensity_weighting",
265
- "T-learner - Placeholder": "t_learner",
266
- "S-learner - Placeholder": "s_learner"
267
- }
268
- selected_method_code = method_map[estimation_method]
269
-
270
- try:
271
- response = requests.post(
272
- f"{FLASK_API_URL}/treatment/estimate_ate",
273
- json={
274
- "data": st.session_state.processed_data,
275
- "treatment_col": treatment_col,
276
- "outcome_col": outcome_col,
277
- "covariates": covariates,
278
- "method": selected_method_code
279
- }
280
- )
281
- if response.status_code == 200:
282
- ate_result = response.json()['result']
283
- st.success(f"Treatment effect estimated using {estimation_method}:")
284
- st.write(f"**Estimated ATE: {ate_result:.4f}**")
285
- st.markdown("""
286
- **Treatment Effect Explanation:**
287
- * **Average Treatment Effect (ATE):** Measures the average causal effect of a treatment (e.g., `StudyHours`) on an outcome (e.g., `FinalExamScore`) across the entire population.
288
- * It answers "How much does doing X cause a change in Y?".
289
- * This estimation attempts to control for confounders (variables that influence both treatment and outcome) to isolate the true causal effect.
290
- """)
291
- else:
292
- st.error(f"Error during ATE estimation: {response.json().get('detail', 'Unknown error')}")
293
- except requests.exceptions.ConnectionError:
294
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
295
- except Exception as e:
296
- st.error(f"An unexpected error occurred during ATE estimation: {e}")
297
- else:
298
- st.info("Please preprocess data first to estimate treatment effects.")
299
-
300
- # --- Optional Advanced Add-Ons (Future Considerations) ---
301
- st.header("Optional Advanced Add-Ons (Future Work) 🚀")
302
- st.markdown("""
303
- - **🔄 Auto-causal graph refresh if dataset updates:** This would involve setting up a background process (e.g., using `watchfiles` with a separate service or integrated carefully into Flask/Streamlit) that monitors changes to the source CSV file. Upon detection, it would re-run the preprocessing and causal discovery, updating the dashboard live. This requires more complex architecture (e.g., WebSockets for real-time updates to Streamlit or scheduled background tasks).
304
- - **🕰️ Time-Series Causal Discovery (e.g., Granger Causality):** This requires handling time-indexed data and implementing algorithms specifically designed for temporal causal relationships. It would involve a separate data input and discovery module.
305
- """)
306
-
307
- st.markdown("---")
 
 
 
 
 
 
 
 
 
 
308
  st.info("Developed by CausalBox Team. For support, please contact us.")
 
1
+ # streamlit_app.py
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import requests
5
+ import json
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ import numpy as np # For random array in placeholders
9
+ import os
10
+
11
+ # Configuration
12
+ FLASK_API_URL = "http://localhost:5000" # Ensure this matches your Flask app's host and port
13
+
14
+ st.set_page_config(layout="wide", page_title="CausalBox Toolkit")
15
+
16
+ st.title("🔬 CausalBox: A Causal Inference Toolkit")
17
+ st.markdown("Uncover causal relationships, simulate interventions, and estimate treatment effects.")
18
+
19
+ # --- Session State Initialization ---
20
+ if 'processed_data' not in st.session_state:
21
+ st.session_state.processed_data = None
22
+ if 'processed_columns' not in st.session_state:
23
+ st.session_state.processed_columns = None
24
+ if 'causal_graph_adj' not in st.session_state:
25
+ st.session_state.causal_graph_adj = None
26
+ if 'causal_graph_nodes' not in st.session_state:
27
+ st.session_state.causal_graph_nodes = None
28
+
29
+ # --- Data Preprocessing Module ---
30
+ st.header("1. Data Preprocessor 🧹")
31
+ st.write("Upload your CSV dataset or use a generated sample dataset.")
32
+
33
+ # Option to use generated sample dataset
34
+ if st.button("Use Sample Dataset (sample_dataset.csv)"):
35
+ # In a real scenario, Streamlit would serve the file or you'd load it directly if local.
36
+ # For this setup, we assume the Flask backend can access it or you manually upload it once.
37
+ # For demonstration, we'll simulate loading a generic DataFrame.
38
+ # In a full deployment, you'd have a mechanism to either:
39
+ # a) Have Flask serve the sample file, or
40
+ # b) Directly load it in Streamlit if the app and data are co-located.
41
+ try:
42
+ # Assuming the sample dataset is accessible or you are testing locally with `scripts/generate_data.py`
43
+ # and then manually uploading this generated file.
44
+ # For simplicity, we'll create a dummy df here if not actually uploaded.
45
+ sample_df_path = "data/sample_dataset.csv" # Path relative to main.py or Streamlit app execution
46
+ if os.path.exists(sample_df_path):
47
+ sample_df = pd.read_csv(sample_df_path)
48
+ st.success(f"Loaded sample dataset from {sample_df_path}. Please upload this file if running from different directory.")
49
+ else:
50
+ st.warning("Sample dataset not found at data/sample_dataset.csv.")
51
+ # Dummy DataFrame for demonstration if sample file isn't found
52
+ sample_df = pd.DataFrame(np.random.rand(10, 5), columns=[f'col_{i}' for i in range(5)])
53
+
54
+ # Convert to JSON for Flask API call
55
+ files = {'file': ('sample_dataset.csv', sample_df.to_csv(index=False), 'text/csv')}
56
+ response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
57
+
58
+ if response.status_code == 200:
59
+ result = response.json()
60
+ st.session_state.processed_data = result['data']
61
+ st.session_state.processed_columns = result['columns']
62
+ st.success("Sample dataset preprocessed successfully!")
63
+ st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
64
+ else:
65
+ st.error(f"Error preprocessing sample dataset: {response.json().get('detail', 'Unknown error')}")
66
+ except Exception as e:
67
+ st.error(f"Could not load or process sample dataset: {e}")
68
+
69
+
70
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
71
+ if uploaded_file is not None:
72
+ st.info("Uploading and preprocessing data...")
73
+ files = {'file': (uploaded_file.name, uploaded_file.getvalue(), 'text/csv')}
74
+ try:
75
+ response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
76
+ if response.status_code == 200:
77
+ result = response.json()
78
+ st.session_state.processed_data = result['data']
79
+ st.session_state.processed_columns = result['columns']
80
+ st.success("File preprocessed successfully!")
81
+ st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
82
+ else:
83
+ st.error(f"Error during preprocessing: {response.json().get('detail', 'Unknown error')}")
84
+ except requests.exceptions.ConnectionError:
85
+ st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
86
+ except Exception as e:
87
+ st.error(f"An unexpected error occurred: {e}")
88
+
89
+ # --- Causal Discovery Module ---
90
+ st.header("2. Causal Discovery 🕵️‍♂️")
91
+ if st.session_state.processed_data:
92
+ st.write("Learn the causal structure from your preprocessed data.")
93
+
94
+ discovery_algo = st.selectbox(
95
+ "Select Causal Discovery Algorithm:",
96
+ ("PC Algorithm", "GES (Greedy Equivalence Search) - Placeholder", "NOTEARS - Placeholder")
97
+ )
98
+
99
+ if st.button("Discover Causal Graph"):
100
+ st.info(f"Discovering graph using {discovery_algo}...")
101
+ algo_map = {
102
+ "PC Algorithm": "pc",
103
+ "GES (Greedy Equivalence Search) - Placeholder": "ges",
104
+ "NOTEARS - Placeholder": "notears"
105
+ }
106
+ selected_algo_code = algo_map[discovery_algo]
107
+
108
+ try:
109
+ response = requests.post(
110
+ f"{FLASK_API_URL}/discover/",
111
+ json={"data": st.session_state.processed_data, "algorithm": selected_algo_code}
112
+ )
113
+ if response.status_code == 200:
114
+ result = response.json()
115
+ st.session_state.causal_graph_adj = result['graph']
116
+ st.session_state.causal_graph_nodes = st.session_state.processed_columns
117
+ st.success("Causal graph discovered!")
118
+ st.subheader("Causal Graph Visualization")
119
+ # Visualization will be handled by the Causal Graph Visualizer section
120
+ else:
121
+ st.error(f"Error during causal discovery: {response.json().get('detail', 'Unknown error')}")
122
+ except requests.exceptions.ConnectionError:
123
+ st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
124
+ except Exception as e:
125
+ st.error(f"An unexpected error occurred: {e}")
126
+ else:
127
+ st.info("Please preprocess data first to enable causal discovery.")
128
+
129
+ # --- Causal Graph Visualizer Module ---
130
+ st.header("3. Causal Graph Visualizer 📊")
131
+ if st.session_state.causal_graph_adj and st.session_state.causal_graph_nodes:
132
+ st.write("Interactive visualization of the discovered causal graph.")
133
+ try:
134
+ response = requests.post(
135
+ f"{FLASK_API_URL}/visualize/graph",
136
+ json={"graph": st.session_state.causal_graph_adj, "nodes": st.session_state.causal_graph_nodes}
137
+ )
138
+ if response.status_code == 200:
139
+ graph_json = response.json()['graph']
140
+ fig = go.Figure(json.loads(graph_json))
141
+ st.plotly_chart(fig, use_container_width=True)
142
+ st.markdown("""
143
+ **Graph Explanation:**
144
+ * **Nodes:** Represent variables in your dataset.
145
+ * **Arrows (Edges):** Indicate a direct causal influence from one variable (the tail) to another (the head).
146
+ * **No Arrow:** Suggests no direct causal relationship was found, or the relationship is mediated by other variables.
147
+
148
+ This graph helps answer "Why did it happen?" by showing the structural relationships.
149
+ """)
150
+ else:
151
+ st.error(f"Error visualizing graph: {response.json().get('detail', 'Unknown error')}")
152
+ except requests.exceptions.ConnectionError:
153
+ st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
154
+ except Exception as e:
155
+ st.error(f"An unexpected error occurred during visualization: {e}")
156
+ else:
157
+ st.info("Please discover a causal graph first to visualize it.")
158
+
159
+
160
+ # --- Do-Calculus Engine Module ---
161
+ st.header("4. Do-Calculus Engine 🧪")
162
+ if st.session_state.processed_data and st.session_state.causal_graph_adj:
163
+ st.write("Simulate interventions and observe their effects based on the causal graph.")
164
+
165
+ intervention_var = st.selectbox(
166
+ "Select variable to intervene on:",
167
+ st.session_state.processed_columns,
168
+ key="inter_var_select"
169
+ )
170
+ # Attempt to infer type for intervention_value input
171
+ # Simplified approach: assuming numerical for now due to preprocessor output
172
+ if intervention_var and isinstance(st.session_state.processed_data[0][intervention_var], (int, float)):
173
+ intervention_value = st.number_input(f"Set '{intervention_var}' to value:", key="inter_val_input")
174
+ else: # Treat as string/categorical for input, then try to preprocess for API
175
+ intervention_value = st.text_input(f"Set '{intervention_var}' to value:", key="inter_val_input_text")
176
+ st.warning("Categorical intervention values might require specific encoding logic on the backend.")
177
+
178
+ if st.button("Perform Intervention"):
179
+ st.info(f"Performing intervention: do('{intervention_var}' = {intervention_value})...")
180
+ try:
181
+ response = requests.post(
182
+ f"{FLASK_API_URL}/intervene/",
183
+ json={
184
+ "data": st.session_state.processed_data,
185
+ "intervention_var": intervention_var,
186
+ "intervention_value": intervention_value,
187
+ "graph": st.session_state.causal_graph_adj # Pass graph for advanced do-calculus
188
+ }
189
+ )
190
+ if response.status_code == 200:
191
+ intervened_data = pd.DataFrame(response.json()['intervened_data'])
192
+ st.success("Intervention simulated successfully!")
193
+ st.subheader("Intervened Data (First 10 rows)")
194
+ st.dataframe(intervened_data.head(10))
195
+
196
+ # Simple comparison visualization (e.g., histogram of outcome variable)
197
+ if st.session_state.processed_columns and 'FinalExamScore' in st.session_state.processed_columns:
198
+ original_df = pd.DataFrame(st.session_state.processed_data)
199
+ fig_dist = go.Figure()
200
+ fig_dist.add_trace(go.Histogram(x=original_df['FinalExamScore'], name='Original', opacity=0.7))
201
+ fig_dist.add_trace(go.Histogram(x=intervened_data['FinalExamScore'], name='Intervened', opacity=0.7))
202
+
203
+ st.plotly_chart(fig_dist, use_container_width=True)
204
+ st.markdown("""
205
+ **Intervention Explanation:**
206
+ * By simulating `do(X=x)`, we are forcing the value of X, effectively breaking its causal links from its parents.
207
+ * The graph above shows the distribution of a key outcome variable (e.g., `FinalExamScore`) before and after the intervention.
208
+ * This helps answer "What if we do this instead?" by showing the predicted outcome.
209
+ """)
210
+ else:
211
+ st.info("Consider adding a relevant outcome variable to your dataset for better intervention analysis.")
212
+ else:
213
+ st.error(f"Error during intervention: {response.json().get('detail', 'Unknown error')}")
214
+ except requests.exceptions.ConnectionError:
215
+ st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
216
+ except Exception as e:
217
+ st.error(f"An unexpected error occurred during intervention: {e}")
218
+ else:
219
+ st.info("Please preprocess data and discover a causal graph first to perform interventions.")
220
+
221
+ # --- Treatment Effect Estimator Module ---
222
+ st.header("5. Treatment Effect Estimator 🎯")
223
+ if st.session_state.processed_data:
224
+ st.write("Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).")
225
+
226
+ col1, col2 = st.columns(2)
227
+ with col1:
228
+ treatment_col = st.selectbox(
229
+ "Select Treatment Variable:",
230
+ st.session_state.processed_columns,
231
+ key="treat_col_select"
232
+ )
233
+ with col2:
234
+ outcome_col = st.selectbox(
235
+ "Select Outcome Variable:",
236
+ st.session_state.processed_columns,
237
+ key="outcome_col_select"
238
+ )
239
+
240
+ all_cols_except_treat_outcome = [col for col in st.session_state.processed_columns if col not in [treatment_col, outcome_col]]
241
+ covariates = st.multiselect(
242
+ "Select Covariates (confounders):",
243
+ all_cols_except_treat_outcome,
244
+ default=all_cols_except_treat_outcome, # Default to all other columns
245
+ key="covariates_select"
246
+ )
247
+
248
+ estimation_method = st.selectbox(
249
+ "Select Estimation Method:",
250
+ (
251
+ "Linear Regression ATE",
252
+ "Propensity Score Matching - Placeholder",
253
+ "Inverse Propensity Weighting - Placeholder",
254
+ "T-learner - Placeholder",
255
+ "S-learner - Placeholder"
256
+ )
257
+ )
258
+
259
+ if st.button("Estimate Treatment Effect"):
260
+ st.info(f"Estimating treatment effect using {estimation_method}...")
261
+ method_map = {
262
+ "Linear Regression ATE": "linear_regression",
263
+ "Propensity Score Matching - Placeholder": "propensity_score_matching",
264
+ "Inverse Propensity Weighting - Placeholder": "inverse_propensity_weighting",
265
+ "T-learner - Placeholder": "t_learner",
266
+ "S-learner - Placeholder": "s_learner"
267
+ }
268
+ selected_method_code = method_map[estimation_method]
269
+
270
+ try:
271
+ response = requests.post(
272
+ f"{FLASK_API_URL}/treatment/estimate_ate",
273
+ json={
274
+ "data": st.session_state.processed_data,
275
+ "treatment_col": treatment_col,
276
+ "outcome_col": outcome_col,
277
+ "covariates": covariates,
278
+ "method": selected_method_code
279
+ }
280
+ )
281
+ if response.status_code == 200:
282
+ ate_result = response.json()['result']
283
+ st.success(f"Treatment effect estimated using {estimation_method}:")
284
+ st.write(f"**Estimated ATE: {ate_result:.4f}**")
285
+ st.markdown("""
286
+ **Treatment Effect Explanation:**
287
+ * **Average Treatment Effect (ATE):** Measures the average causal effect of a treatment (e.g., `StudyHours`) on an outcome (e.g., `FinalExamScore`) across the entire population.
288
+ * It answers "How much does doing X cause a change in Y?".
289
+ * This estimation attempts to control for confounders (variables that influence both treatment and outcome) to isolate the true causal effect.
290
+ """)
291
+ else:
292
+ st.error(f"Error during ATE estimation: {response.json().get('detail', 'Unknown error')}")
293
+ except requests.exceptions.ConnectionError:
294
+ st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
295
+ except Exception as e:
296
+ st.error(f"An unexpected error occurred during ATE estimation: {e}")
297
+ else:
298
+ st.info("Please preprocess data first to estimate treatment effects.")
299
+
300
+ # --- Prediction Module ---
301
+ st.header("6. Prediction Module 📈")
302
+ if st.session_state.processed_data:
303
+ st.write("Train a machine learning model for prediction (Regression or Classification).")
304
+
305
+ prediction_type = st.selectbox(
306
+ "Select Prediction Type:",
307
+ ("Regression", "Classification"),
308
+ key="prediction_type_select"
309
+ )
310
+
311
+ all_columns = st.session_state.processed_columns
312
+
313
+ suitable_target_columns = []
314
+ if st.session_state.processed_data:
315
+ temp_df = pd.DataFrame(st.session_state.processed_data)
316
+ for col in all_columns:
317
+ # For classification, check if column is object type (string), boolean,
318
+ # or has a limited number of unique integer values (e.g., less than 20 unique values)
319
+ if prediction_type == 'Classification':
320
+ if temp_df[col].dtype == 'object' or temp_df[col].dtype == 'bool':
321
+ suitable_target_columns.append(col)
322
+ elif pd.api.types.is_integer_dtype(temp_df[col]) and temp_df[col].nunique() < 20: # Heuristic for discrete integers
323
+ suitable_target_columns.append(col)
324
+ # For regression, primarily numerical columns
325
+ elif prediction_type == 'Regression':
326
+ if pd.api.types.is_numeric_dtype(temp_df[col]):
327
+ suitable_target_columns.append(col)
328
+
329
+ if not suitable_target_columns:
330
+ st.warning(f"No suitable target columns found for {prediction_type}. Please check your data types.")
331
+ target_col = None # Set to None to prevent error if no columns are found
332
+ else:
333
+ # Try to pre-select the currently chosen target_col if it's still suitable
334
+ # Otherwise, default to the first suitable column
335
+ if 'target_col_select' in st.session_state and st.session_state.target_col_select in suitable_target_columns:
336
+ default_target_index = suitable_target_columns.index(st.session_state.target_col_select)
337
+ else:
338
+ default_target_index = 0
339
+
340
+ target_col = st.selectbox(
341
+ "Select Target Variable:",
342
+ suitable_target_columns,
343
+ index=default_target_index,
344
+ key="target_col_select"
345
+ )
346
+
347
+ # Filter out the target column from feature options
348
+ feature_options = [col for col in all_columns if col != target_col]
349
+ feature_cols = st.multiselect(
350
+ "Select Feature Variables:",
351
+ feature_options,
352
+ default=feature_options, # Default to all other columns
353
+ key="feature_cols_select"
354
+ )
355
+
356
+ if st.button("Train Model & Predict", key="train_predict_button"):
357
+ if not target_col or not feature_cols:
358
+ st.warning("Please select a target variable and at least one feature variable.")
359
+ else:
360
+ st.info(f"Training {prediction_type} model using Random Forest...")
361
+ try:
362
+ response = requests.post(
363
+ f"{FLASK_API_URL}/prediction/train_predict",
364
+ json={
365
+ "data": st.session_state.processed_data,
366
+ "target_col": target_col,
367
+ "feature_cols": feature_cols,
368
+ "prediction_type": prediction_type.lower()
369
+ }
370
+ )
371
+
372
+ if response.status_code == 200:
373
+ results = response.json()['results']
374
+ st.success(f"{prediction_type} Model Trained Successfully!")
375
+ st.subheader("Model Performance")
376
+
377
+ if prediction_type == 'Regression':
378
+ st.write(f"**R-squared:** {results['r2_score']:.4f}")
379
+ st.write(f"**Mean Squared Error (MSE):** {results['mean_squared_error']:.4f}")
380
+ st.write(f"**Root Mean Squared Error (RMSE):** {results['root_mean_squared_error']:.4f}")
381
+
382
+ st.subheader("Actual vs. Predicted Plot")
383
+ actual_predicted_df = pd.DataFrame(results['actual_vs_predicted'])
384
+ fig_reg = px.scatter(actual_predicted_df, x='Actual', y='Predicted',
385
+ title='Actual vs. Predicted Values',
386
+ labels={'Actual': f'Actual {target_col}', 'Predicted': f'Predicted {target_col}'})
387
+ fig_reg.add_trace(go.Scatter(x=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
388
+ y=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
389
+ mode='lines', name='Ideal Fit', line=dict(dash='dash', color='red')))
390
+ st.plotly_chart(fig_reg, use_container_width=True)
391
+
392
+ st.subheader("Residual Plot")
393
+ actual_predicted_df['Residuals'] = actual_predicted_df['Actual'] - actual_predicted_df['Predicted']
394
+ fig_res = px.scatter(actual_predicted_df, x='Predicted', y='Residuals',
395
+ title='Residual Plot',
396
+ labels={'Predicted': f'Predicted {target_col}', 'Residuals': 'Residuals'})
397
+ fig_res.add_hline(y=0, line_dash="dash", line_color="red")
398
+ st.plotly_chart(fig_res, use_container_width=True)
399
+
400
+ elif prediction_type == 'Classification':
401
+ st.write(f"**Accuracy:** {results['accuracy']:.4f}")
402
+ st.write(f"**Precision (weighted):** {results['precision']:.4f}")
403
+ st.write(f"**Recall (weighted):** {results['recall']:.4f}")
404
+ st.write(f"**F1-Score (weighted):** {results['f1_score']:.4f}")
405
+
406
+ st.subheader("Confusion Matrix")
407
+ conf_matrix = results['confusion_matrix']
408
+ class_labels = results.get('class_labels', [str(i) for i in range(len(conf_matrix))])
409
+ fig_cm = px.imshow(conf_matrix,
410
+ labels=dict(x="Predicted", y="True", color="Count"),
411
+ x=class_labels,
412
+ y=class_labels,
413
+ text_auto=True,
414
+ color_continuous_scale="Viridis",
415
+ title="Confusion Matrix")
416
+ st.plotly_chart(fig_cm, use_container_width=True)
417
+
418
+ st.subheader("Classification Report")
419
+ # Convert dict to DataFrame for nice display
420
+ report_df = pd.DataFrame(results['classification_report']).transpose()
421
+ st.dataframe(report_df)
422
+
423
+ st.subheader("Feature Importances")
424
+ feature_importances_df = pd.DataFrame(list(results['feature_importances'].items()), columns=['Feature', 'Importance'])
425
+ fig_fi = px.bar(feature_importances_df, x='Importance', y='Feature', orientation='h',
426
+ title='Feature Importances',
427
+ labels={'Importance': 'Importance Score', 'Feature': 'Feature Name'})
428
+ fig_fi.update_layout(yaxis={'categoryorder':'total ascending'}) # Sort bars
429
+ st.plotly_chart(fig_fi, use_container_width=True)
430
+ else:
431
+ st.error(f"Error during prediction: {response.json().get('detail', 'Unknown error')}")
432
+ except requests.exceptions.ConnectionError:
433
+ st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
434
+ except Exception as e:
435
+ st.error(f"An unexpected error occurred during prediction: {e}")
436
+ else:
437
+ st.info("Please preprocess data first to use the Prediction Module.")
438
+
439
+ # --- Time Series Causal Discovery Module ---
440
+ st.header("7. Time Series Causal Discovery ⏰")
441
+ if st.session_state.processed_data:
442
+ st.write("Infer causal relationships in time-series data using Granger Causality.")
443
+ st.info("Ensure your dataset includes a timestamp column and that variables are numeric.")
444
+
445
+ all_columns = st.session_state.processed_columns
446
+
447
+ # Heuristic to suggest potential timestamp columns (object/string type, or first column)
448
+ potential_ts_cols = [col for col in all_columns if pd.DataFrame(st.session_state.processed_data)[col].dtype == 'object']
449
+ if not potential_ts_cols and all_columns: # If no object columns, suggest the first column
450
+ potential_ts_cols = [all_columns[0]]
451
+
452
+ timestamp_col = st.selectbox(
453
+ "Select Timestamp Column:",
454
+ potential_ts_cols if potential_ts_cols else ["No suitable timestamp column found. Please check data."],
455
+ key="ts_col_select"
456
+ )
457
+
458
+ # Filter out timestamp column and non-numeric columns for analysis
459
+ variables_for_ts_analysis = [
460
+ col for col in all_columns if col != timestamp_col and pd.api.types.is_numeric_dtype(pd.DataFrame(st.session_state.processed_data)[col])
461
+ ]
462
+
463
+ variables_to_analyze = st.multiselect(
464
+ "Select Variables to Analyze for Granger Causality:",
465
+ variables_for_ts_analysis,
466
+ default=variables_for_ts_analysis,
467
+ key="ts_vars_select"
468
+ )
469
+
470
+ max_lags = st.number_input(
471
+ "Max Lags (for Granger Causality):",
472
+ min_value=1,
473
+ value=5, # Default value
474
+ step=1,
475
+ help="The maximum number of lagged observations to consider for causality."
476
+ )
477
+
478
+ if st.button("Discover Time Series Causality", key="ts_discover_button"):
479
+ if not timestamp_col or not variables_to_analyze:
480
+ st.warning("Please select a timestamp column and at least one variable to analyze.")
481
+ elif "No suitable timestamp column found" in timestamp_col:
482
+ st.error("Cannot proceed. Please ensure your data has a suitable timestamp column.")
483
+ else:
484
+ st.info("Performing Granger Causality tests...")
485
+ try:
486
+ response = requests.post(
487
+ f"{FLASK_API_URL}/timeseries/discover_causality",
488
+ json={
489
+ "data": st.session_state.processed_data,
490
+ "timestamp_col": timestamp_col,
491
+ "variables_to_analyze": variables_to_analyze,
492
+ "max_lags": max_lags
493
+ }
494
+ )
495
+
496
+ if response.status_code == 200:
497
+ results = response.json()['results']
498
+ st.success("Time Series Causal Discovery Complete!")
499
+ st.subheader("Granger Causality Test Results")
500
+
501
+ if results:
502
+ # Convert results to a DataFrame for better display
503
+ results_df = pd.DataFrame(results)
504
+ results_df['p_value'] = results_df['p_value'].round(4) # Round p-values
505
+ st.dataframe(results_df)
506
+
507
+ st.markdown("**Interpretation:** A small p-value (typically < 0.05) suggests that the 'cause' variable Granger-causes the 'effect' variable. This means past values of the 'cause' variable help predict future values of the 'effect' variable, even when past values of the 'effect' variable are considered.")
508
+ st.markdown(f"*(Note: Granger Causality implies predictive causality, not necessarily true mechanistic causality. Also, ensure your time series are stationary for robust results.)*")
509
+
510
+ # Optionally, visualize a simple causality graph
511
+ st.subheader("Granger Causality Graph")
512
+ fig_ts_graph = go.Figure()
513
+ nodes = []
514
+ edges = []
515
+ edge_colors = []
516
+
517
+ # Add nodes
518
+ for i, var in enumerate(variables_to_analyze):
519
+ nodes.append(dict(id=var, label=var, x=np.cos(i*2*np.pi/len(variables_to_analyze)), y=np.sin(i*2*np.pi/len(variables_to_analyze))))
520
+
521
+ # Add edges
522
+ for res in results:
523
+ if res['p_value'] < 0.05: # Consider it a causal link if p-value is below significance
524
+ edges.append(dict(source=res['cause'], target=res['effect'], value=1/res['p_value'], title=f"p={res['p_value']:.4f}"))
525
+ edge_colors.append("blue")
526
+ else:
527
+ # Optional: Show non-significant edges in a different color or omit
528
+ pass
529
+
530
+ # Use a simple network graph layout (Spring layout is common)
531
+ # For a truly interactive graph, you might need a different library or more complex Plotly setup
532
+ # This is a very basic attempt to visualize; consider more robust solutions like NetworkX + Plotly/Dash
533
+
534
+ # Simple way to draw arrows for significant relationships
535
+ significant_edges = [edge for edge in results if edge['p_value'] < 0.05]
536
+ if significant_edges:
537
+ st.write("Visualizing significant (p < 0.05) Granger causal links:")
538
+ # This needs a more robust way to draw directed edges in plotly if using just scatter/lines.
539
+ # For now, let's just list them clearly.
540
+ for edge in significant_edges:
541
+ st.write(f"➡️ **{edge['cause']}** Granger-causes **{edge['effect']}** (p={edge['p_value']:.4f})")
542
+ else:
543
+ st.info("No significant Granger causal links found at p < 0.05.")
544
+
545
+ else:
546
+ st.info("No Granger Causality relationships found or data insufficient.")
547
+
548
+ else:
549
+ st.error(f"Error during time-series causal discovery: {response.json().get('detail', 'Unknown error')}")
550
+ except requests.exceptions.ConnectionError:
551
+ st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
552
+ except Exception as e:
553
+ st.error(f"An unexpected error occurred during time-series causal discovery: {e}")
554
+ else:
555
+ st.info("Please preprocess data first to use the Time Series Causal Discovery Module.")
556
+
557
+ # --- CausalBox Chat Assistant ---
558
+ st.header("8. CausalBox Chat Assistant 🤖")
559
+ st.write("Ask questions about your loaded dataset, causal concepts, or the discovered causal graph!")
560
+
561
+ # Initialize chat history in session state
562
+ if "messages" not in st.session_state:
563
+ st.session_state.messages = []
564
+
565
+ # Display chat messages from history on app rerun
566
+ for message in st.session_state.messages:
567
+ with st.chat_message(message["role"]):
568
+ st.markdown(message["content"])
569
+
570
+ # Accept user input
571
+ if prompt := st.chat_input("Ask me anything about CausalBox..."):
572
+ # Add user message to chat history
573
+ st.session_state.messages.append({"role": "user", "content": prompt})
574
+ # Display user message in chat message container
575
+ with st.chat_message("user"):
576
+ st.markdown(prompt)
577
+
578
+ # Prepare session context to send to the backend
579
+ session_context = {
580
+ "processed_data": st.session_state.processed_data,
581
+ "processed_columns": st.session_state.processed_columns,
582
+ "causal_graph_adj": st.session_state.causal_graph_adj,
583
+ "causal_graph_nodes": st.session_state.causal_graph_nodes,
584
+ # Add any other relevant session state variables that the chatbot might need
585
+ }
586
+
587
+ with st.spinner("Thinking..."):
588
+ try:
589
+ response = requests.post(
590
+ f"{FLASK_API_URL}/chatbot/message",
591
+ json={
592
+ "user_message": prompt,
593
+ "session_context": session_context
594
+ }
595
+ )
596
+
597
+ if response.status_code == 200:
598
+ chatbot_response_text = response.json().get('response', 'Sorry, I could not generate a response.')
599
+ else:
600
+ chatbot_response_text = f"Error from chatbot backend: {response.json().get('detail', 'Unknown error')}"
601
+ except requests.exceptions.ConnectionError:
602
+ chatbot_response_text = f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running."
603
+ except Exception as e:
604
+ chatbot_response_text = f"An unexpected error occurred while getting chatbot response: {e}"
605
+
606
+ # Display assistant response in chat message container
607
+ with st.chat_message("assistant"):
608
+ st.markdown(chatbot_response_text)
609
+ # Add assistant response to chat history
610
+ st.session_state.messages.append({"role": "assistant", "content": chatbot_response_text})
611
+
612
+ # --- Future Work (Simplified) ---
613
+ st.header("Future Work 🚀")
614
+ st.markdown("""
615
+ - **🔄 Auto-causal graph refresh:** Monitor dataset updates and automatically refresh the causal graph.
616
+ """)
617
+
618
+ st.markdown("---")
619
  st.info("Developed by CausalBox Team. For support, please contact us.")
utils/__pycache__/casual_algorithms.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/casual_algorithms.cpython-310.pyc and b/utils/__pycache__/casual_algorithms.cpython-310.pyc differ
 
utils/__pycache__/causal_chatbot.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
utils/__pycache__/do_calculus.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/do_calculus.cpython-310.pyc and b/utils/__pycache__/do_calculus.cpython-310.pyc differ
 
utils/__pycache__/graph_utils.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/graph_utils.cpython-310.pyc and b/utils/__pycache__/graph_utils.cpython-310.pyc differ
 
utils/__pycache__/prediction_models.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
utils/__pycache__/preprocessor.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/preprocessor.cpython-310.pyc and b/utils/__pycache__/preprocessor.cpython-310.pyc differ
 
utils/__pycache__/time_series_causal.cpython-310.pyc ADDED
Binary file (2.21 kB). View file
 
utils/__pycache__/treatment_effects.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/treatment_effects.cpython-310.pyc and b/utils/__pycache__/treatment_effects.cpython-310.pyc differ
 
utils/casual_algorithms.py CHANGED
@@ -1,64 +1,64 @@
1
- # utils/causal_algorithms.py
2
- import networkx as nx
3
- import pandas as pd
4
- import numpy as np
5
- from causallearn.search.ConstraintBased.PC import pc
6
- # from causallearn.search.ScoreBased.GES import ges # Example import for GES
7
- # from notears import notears_linear # Example import for NOTEARS
8
-
9
- class CausalDiscoveryAlgorithms:
10
- def pc_algorithm(self, df, alpha=0.05):
11
- """
12
- Run PC algorithm to learn causal graph.
13
- Returns a directed graph's adjacency matrix.
14
- Requires numerical data.
15
- """
16
- data_array = df.to_numpy()
17
- cg = pc(data_array, alpha=alpha, indep_test="fisherz")
18
- adj_matrix = cg.G.graph
19
- return adj_matrix
20
-
21
- def ges_algorithm(self, df):
22
- """
23
- Placeholder for GES (Greedy Equivalence Search) algorithm.
24
- Returns a directed graph's adjacency matrix.
25
- You would implement or integrate the GES algorithm here.
26
- """
27
- # Example: G, edges = ges(data_array)
28
- # For now, returning a simplified correlation-based graph for demonstration
29
- print("GES algorithm is a placeholder. Using a simplified correlation-based graph.")
30
- G = nx.DiGraph()
31
- nodes = df.columns
32
- G.add_nodes_from(nodes)
33
- corr_matrix = df.corr().abs()
34
- threshold = 0.3
35
- for i, col1 in enumerate(nodes):
36
- for col2 in nodes[i+1:]:
37
- if corr_matrix.loc[col1, col2] > threshold:
38
- if np.random.rand() > 0.5:
39
- G.add_edge(col1, col2)
40
- else:
41
- G.add_edge(col2, col1)
42
- return nx.to_numpy_array(G) # Convert to adjacency matrix
43
-
44
- def notears_algorithm(self, df):
45
- """
46
- Placeholder for NOTEARS algorithm.
47
- Returns a directed graph's adjacency matrix.
48
- You would implement or integrate the NOTEARS algorithm here.
49
- """
50
- # Example: W_est = notears_linear(data_array)
51
- print("NOTEARS algorithm is a placeholder. Using a simplified correlation-based graph.")
52
- G = nx.DiGraph()
53
- nodes = df.columns
54
- G.add_nodes_from(nodes)
55
- corr_matrix = df.corr().abs()
56
- threshold = 0.3
57
- for i, col1 in enumerate(nodes):
58
- for col2 in nodes[i+1:]:
59
- if corr_matrix.loc[col1, col2] > threshold:
60
- if np.random.rand() > 0.5:
61
- G.add_edge(col1, col2)
62
- else:
63
- G.add_edge(col2, col1)
64
  return nx.to_numpy_array(G) # Convert to adjacency matrix
 
1
+ # utils/causal_algorithms.py
2
+ import networkx as nx
3
+ import pandas as pd
4
+ import numpy as np
5
+ from causallearn.search.ConstraintBased.PC import pc
6
+ # from causallearn.search.ScoreBased.GES import ges # Example import for GES
7
+ # from notears import notears_linear # Example import for NOTEARS
8
+
9
+ class CausalDiscoveryAlgorithms:
10
+ def pc_algorithm(self, df, alpha=0.05):
11
+ """
12
+ Run PC algorithm to learn causal graph.
13
+ Returns a directed graph's adjacency matrix.
14
+ Requires numerical data.
15
+ """
16
+ data_array = df.to_numpy()
17
+ cg = pc(data_array, alpha=alpha, indep_test="fisherz")
18
+ adj_matrix = cg.G.graph
19
+ return adj_matrix
20
+
21
+ def ges_algorithm(self, df):
22
+ """
23
+ Placeholder for GES (Greedy Equivalence Search) algorithm.
24
+ Returns a directed graph's adjacency matrix.
25
+ You would implement or integrate the GES algorithm here.
26
+ """
27
+ # Example: G, edges = ges(data_array)
28
+ # For now, returning a simplified correlation-based graph for demonstration
29
+ print("GES algorithm is a placeholder. Using a simplified correlation-based graph.")
30
+ G = nx.DiGraph()
31
+ nodes = df.columns
32
+ G.add_nodes_from(nodes)
33
+ corr_matrix = df.corr().abs()
34
+ threshold = 0.3
35
+ for i, col1 in enumerate(nodes):
36
+ for col2 in nodes[i+1:]:
37
+ if corr_matrix.loc[col1, col2] > threshold:
38
+ if np.random.rand() > 0.5:
39
+ G.add_edge(col1, col2)
40
+ else:
41
+ G.add_edge(col2, col1)
42
+ return nx.to_numpy_array(G) # Convert to adjacency matrix
43
+
44
+ def notears_algorithm(self, df):
45
+ """
46
+ Placeholder for NOTEARS algorithm.
47
+ Returns a directed graph's adjacency matrix.
48
+ You would implement or integrate the NOTEARS algorithm here.
49
+ """
50
+ # Example: W_est = notears_linear(data_array)
51
+ print("NOTEARS algorithm is a placeholder. Using a simplified correlation-based graph.")
52
+ G = nx.DiGraph()
53
+ nodes = df.columns
54
+ G.add_nodes_from(nodes)
55
+ corr_matrix = df.corr().abs()
56
+ threshold = 0.3
57
+ for i, col1 in enumerate(nodes):
58
+ for col2 in nodes[i+1:]:
59
+ if corr_matrix.loc[col1, col2] > threshold:
60
+ if np.random.rand() > 0.5:
61
+ G.add_edge(col1, col2)
62
+ else:
63
+ G.add_edge(col2, col1)
64
  return nx.to_numpy_array(G) # Convert to adjacency matrix
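pc_algorithm above returns causal-learn's raw endpoint matrix (cg.G.graph), while the rest of the toolkit works with NetworkX graphs keyed by column names. Below is a hedged conversion sketch, assuming causal-learn's usual encoding (graph[i, j] == -1 with graph[j, i] == 1 means i -> j, and -1/-1 means an undirected edge); verify this against the installed causal-learn version before relying on it.

```python
# Sketch only: map a causal-learn PC endpoint matrix onto a labelled NetworkX DiGraph.
import networkx as nx
import numpy as np

def pc_matrix_to_digraph(adj_matrix, column_names):
    adj = np.asarray(adj_matrix)
    G = nx.DiGraph()
    G.add_nodes_from(column_names)
    for i in range(adj.shape[0]):
        for j in range(adj.shape[1]):
            if i == j:
                continue
            if adj[i, j] == -1 and adj[j, i] == 1:
                G.add_edge(column_names[i], column_names[j])  # i causes j
            elif adj[i, j] == -1 and adj[j, i] == -1 and i < j:
                # Orientation undetermined by PC: keep both arcs so no edge is lost
                G.add_edge(column_names[i], column_names[j])
                G.add_edge(column_names[j], column_names[i])
    return G
```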
utils/causal_chatbot.py ADDED
@@ -0,0 +1,271 @@
 
1
+ # utils/causal_chatbot.py
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langchain_groq import ChatGroq
5
+ from langchain_core.tools import tool
6
+ from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
7
+ from langchain_core.prompts import ChatPromptTemplate
8
+ from utils.preprocessor import summarize_dataframe_for_chatbot
9
+ from utils.graph_utils import get_graph_summary_for_chatbot
10
+ import pandas as pd
11
+
12
+ load_dotenv()
13
+
14
+ # Configure Groq API Key
15
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
16
+ if not GROQ_API_KEY:
17
+ print("ERROR: GROQ_API_KEY environment variable not set.")
18
+ raise ValueError("GROQ_API_KEY is required.")
19
+
20
+ # Debug: confirm the key was loaded without echoing any part of the secret
21
+ print("GROQ_API_KEY loaded from environment.")
22
+ print(f"API key length: {len(GROQ_API_KEY)}")
23
+
24
+ # Initialize the Groq model with LangChain
25
+ try:
26
+ model = ChatGroq(
27
+ model_name="llama-3.3-70b-versatile",
28
+ temperature=0.7,
29
+ groq_api_key=GROQ_API_KEY
30
+ )
31
+ except Exception as e:
32
+ print(f"Error configuring Groq API: {e}")
33
+ model = None
34
+
35
+ def assess_causal_compatibility(data_json: list) -> str:
36
+ """
37
+ Assesses the dataset's compatibility for causal inference analysis.
38
+
39
+ Args:
40
+ data_json: List of dictionaries representing the dataset.
41
+
42
+ Returns:
43
+ String describing the dataset's suitability for causal analysis.
44
+ """
45
+ if not data_json:
46
+ return "No dataset provided for compatibility assessment."
47
+
48
+ try:
49
+ df = pd.DataFrame(data_json)
50
+ num_rows, num_cols = df.shape
51
+ numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
52
+ categorical_cols = df.select_dtypes(include=['object', 'category']).columns
53
+ missing_values = df.isnull().sum().sum()
54
+
55
+ assessment = [
56
+ f"Dataset has {num_rows} rows and {num_cols} columns.",
57
+ f"Numeric columns ({len(numeric_cols)}): {', '.join(numeric_cols) if len(numeric_cols) > 0 else 'None'}.",
58
+ f"Categorical columns ({len(categorical_cols)}): {', '.join(categorical_cols) if len(categorical_cols) > 0 else 'None'}.",
59
+ f"Missing values: {missing_values}."
60
+ ]
61
+
62
+ # Causal compatibility insights
63
+ if num_cols < 3:
64
+ assessment.append("Warning: Dataset has fewer than 3 columns, which may limit causal analysis (e.g., no room for treatment, outcome, and confounders).")
65
+ if len(numeric_cols) == 0:
66
+ assessment.append("Warning: No numeric columns detected. Causal inference often requires numeric variables for treatment or outcome.")
67
+ if missing_values > 0:
68
+ assessment.append("Note: Missing values detected. Preprocessing (e.g., imputation) may be needed for accurate causal analysis.")
69
+ if len(numeric_cols) >= 2 and num_rows > 100:
70
+ assessment.append("Positive: Dataset has multiple numeric columns and sufficient rows, suitable for causal inference with proper preprocessing.")
71
+ else:
72
+ assessment.append("Note: Ensure at least two numeric columns (e.g., treatment and outcome) and sufficient data points for robust causal analysis.")
73
+
74
+ return "\n".join(assessment)
75
+ except Exception as e:
76
+ print(f"Error in assess_causal_compatibility: {e}")
77
+ return "Unable to assess dataset compatibility due to processing error."
78
+
79
+ # Define tools using LangChain's @tool decorator
80
+ @tool
81
+ def get_dataset_info() -> dict:
82
+ """
83
+ Provides summary information and causal compatibility assessment for the currently loaded dataset.
84
+ The dataset is provided by the backend session context.
85
+
86
+ Returns:
87
+ Dictionary containing the dataset summary and compatibility assessment.
88
+ """
89
+ return {"summary": "Dataset will be provided by session context"}
90
+
91
+ @tool
92
+ def get_causal_graph_info() -> dict:
93
+ """
94
+ Provides summary information about the currently discovered causal graph.
95
+ The graph data is provided by the backend session context.
96
+
97
+ Returns:
98
+ Dictionary containing the graph summary.
99
+ """
100
+ return {"summary": "Graph data will be provided by session context"}
101
+
102
+ # Bind tools to the model
103
+ tools = [get_dataset_info, get_causal_graph_info]
104
+ if model:
105
+ model_with_tools = model.bind_tools(tools)
106
+
107
+ def get_chatbot_response(user_message: str, session_context: dict) -> str:
108
+ """
109
+ Gets a response from the Groq chatbot, handling tool calls.
110
+
111
+ Args:
112
+ user_message: The message from the user.
113
+ session_context: Dictionary containing current session data
114
+ (e.g., processed_data, causal_graph_adj, causal_graph_nodes).
115
+
116
+ Returns:
117
+ The chatbot's response message.
118
+ """
119
+ if model is None:
120
+ return "Chatbot is not configured correctly. Please check Groq API key."
121
+
122
+ try:
123
+ # Create a prompt template to guide the model's behavior
124
+ prompt = ChatPromptTemplate.from_messages([
125
+ ("system", """You are CausalBox Assistant, an AI that helps users analyze datasets and causal graphs.
126
+ Use the provided tools to access dataset or graph information. Do NOT generate or guess parameters for tool calls; the backend will provide all necessary data (e.g., dataset or graph details).
127
+ For dataset queries (e.g., "read the dataset", "dataset compatibility"), call `get_dataset_info` without arguments.
128
+ For graph queries (e.g., "describe the causal graph"), call `get_causal_graph_info` without arguments.
129
+ For other questions (e.g., "what is a confounder?"), respond directly with clear, accurate explanations.
130
+
131
+ When you receive tool results, provide a comprehensive analysis and explanation to help the user understand their data and causal analysis possibilities.
132
+
133
+ Examples:
134
+ - User: "Tell me about the dataset" -> Call `get_dataset_info`.
135
+ - User: "Check dataset compatibility for causal analysis" -> Call `get_dataset_info`.
136
+ - User: "Describe the causal graph" -> Call `get_causal_graph_info`.
137
+ - User: "What is a confounder?" -> Respond: "A confounder is a variable that influences both the treatment and outcome, causing a spurious association."
138
+ """),
139
+ ("human", "{user_message}")
140
+ ])
141
+
142
+ # Chain the prompt with the model
143
+ chain = prompt | model_with_tools
144
+
145
+ # Log the user message and session context
146
+ print(f"Processing user message: {user_message}")
147
+ print(f"Session context keys: {list(session_context.keys())}")
148
+
149
+ # Invoke the chain with the user message
150
+ response = chain.invoke({"user_message": user_message})
151
+ print(f"Model response: {response}")
152
+
153
+ # Handle tool calls if present
154
+ if response.tool_calls:
155
+ tool_call = response.tool_calls[0]
156
+ function_name = tool_call["name"]
157
+ function_args = tool_call["args"]
158
+
159
+ print(f"Chatbot calling tool: {function_name} with args: {function_args}")
160
+
161
+ # Map session context to tool arguments
162
+ tool_output = {}
163
+ if function_name == "get_dataset_info":
164
+ data_json = session_context.get("processed_data", [])
165
+ if not isinstance(data_json, list) or not data_json:
166
+ print(f"Invalid or empty data_json: {data_json}")
167
+ return "Error: No valid dataset available."
168
+ tool_output = get_dataset_info.invoke({})
169
+ tool_output["summary"] = summarize_dataframe_for_chatbot(data_json)
170
+ tool_output["causal_compatibility"] = assess_causal_compatibility(data_json)
171
+ elif function_name == "get_causal_graph_info":
172
+ graph_adj = session_context.get("causal_graph_adj", [])
173
+ nodes = session_context.get("causal_graph_nodes", [])
174
+ if not graph_adj or not nodes:
175
+ print("No causal graph data available")
176
+ return "Error: No causal graph available."
177
+ tool_output = get_causal_graph_info.invoke({})
178
+ tool_output["summary"] = get_graph_summary_for_chatbot(graph_adj, nodes)
179
+ else:
180
+ print(f"Unknown tool: {function_name}")
181
+ return f"Error: Unknown tool {function_name}."
182
+
183
+ print(f"Tool output: {tool_output}")
184
+
185
+ # Create the tool output text
186
+ output_text = tool_output["summary"]
187
+ if tool_output.get("causal_compatibility"):
188
+ output_text += "\n\nCausal Compatibility Assessment:\n" + tool_output["causal_compatibility"]
189
+
190
+ # Create messages for the final response - FIXED VERSION
191
+ messages = [
192
+ HumanMessage(content=user_message),
193
+ AIMessage(content="", tool_calls=[tool_call]),
194
+ ToolMessage(content=output_text, tool_call_id=tool_call["id"])
195
+ ]
196
+
197
+ # Create a follow-up prompt to ensure the model provides a comprehensive response
198
+ follow_up_prompt = ChatPromptTemplate.from_messages([
199
+ ("system", """You are CausalBox Assistant. Based on the tool results, provide a comprehensive, helpful response to the user's question.
200
+ Explain the dataset characteristics, causal compatibility, and provide actionable insights for causal analysis.
201
+ Be specific about what the data shows and what causal analysis approaches would be suitable.
202
+ Always provide a complete response, not just acknowledgment."""),
203
+ ("human", "{original_question}"),
204
+ ("assistant", "I'll analyze the dataset information for you."),
205
+ ("human", "Here's the dataset analysis: {tool_results}\n\nPlease provide a comprehensive explanation of this data and its suitability for causal analysis.")
206
+ ])
207
+
208
+ # Get final response from the model with explicit prompting
209
+ print("Invoking model with tool response messages")
210
+ try:
211
+ final_chain = follow_up_prompt | model
212
+ final_response = final_chain.invoke({
213
+ "original_question": user_message,
214
+ "tool_results": output_text
215
+ })
216
+ print(f"Final response content: {final_response.content}")
217
+
218
+ if final_response.content and final_response.content.strip():
219
+ return final_response.content
220
+ else:
221
+ # Fallback response if model still returns empty
222
+ return create_fallback_response(output_text, user_message)
223
+
224
+ except Exception as e:
225
+ print(f"Error in final response generation: {e}")
226
+ return create_fallback_response(output_text, user_message)
227
+
228
+ else:
229
+ print("No tool calls, returning direct response")
230
+ if response.content and response.content.strip():
231
+ return response.content
232
+ else:
233
+ return "I'm ready to help you with causal analysis. Please ask me about your dataset, causal graphs, or any causal inference concepts you'd like to understand."
234
+
235
+ except Exception as e:
236
+ print(f"Error communicating with Groq: {e}")
237
+ return f"Sorry, I'm having trouble processing your request: {str(e)}"
238
+
239
+ def create_fallback_response(tool_output: str, user_message: str) -> str:
240
+ """
241
+ Creates a fallback response when the model returns empty content.
242
+ """
243
+ response_parts = ["Based on your dataset analysis:\n"]
244
+
245
+ if "Dataset Summary:" in tool_output:
246
+ response_parts.append("📊 **Dataset Overview:**")
247
+ summary_part = tool_output.split("Dataset Summary:")[1].split("Causal Compatibility Assessment:")[0]
248
+ response_parts.append(summary_part.strip())
249
+ response_parts.append("")
250
+
251
+ if "Causal Compatibility Assessment:" in tool_output:
252
+ response_parts.append("🔍 **Causal Analysis Compatibility:**")
253
+ compatibility_part = tool_output.split("Causal Compatibility Assessment:")[1]
254
+ response_parts.append(compatibility_part.strip())
255
+ response_parts.append("")
256
+
257
+ # Add specific insights based on the data
258
+ if "FinalExamScore" in tool_output:
259
+ response_parts.append("💡 **Key Insights for Causal Analysis:**")
260
+ response_parts.append("- Your dataset appears to be education-related with variables like FinalExamScore, StudyHours, and TuitionHours")
261
+ response_parts.append("- This is excellent for causal analysis as you can explore questions like:")
262
+ response_parts.append(" • Does increasing study hours causally improve exam scores?")
263
+ response_parts.append(" • What's the causal effect of tutoring (TuitionHours) on performance?")
264
+ response_parts.append(" • How does parental education influence student outcomes?")
265
+ response_parts.append("")
266
+ response_parts.append("🚀 **Next Steps:**")
267
+ response_parts.append("- Consider identifying your treatment variable (e.g., TuitionHours)")
268
+ response_parts.append("- Define your outcome variable (likely FinalExamScore)")
269
+ response_parts.append("- Identify potential confounders (ParentalEducation, SchoolType)")
270
+
271
+ return "\n".join(response_parts)
utils/do_calculus.py CHANGED
@@ -1,52 +1,52 @@
1
- # utils/do_calculus.py
2
- import pandas as pd
3
- import numpy as np
4
- import networkx as nx
5
-
6
- class DoCalculus:
7
- def __init__(self, graph):
8
- self.graph = graph
9
-
10
- def intervene(self, data, intervention_var, intervention_value):
11
- """
12
- Simulate do(X=x) intervention on a variable.
13
- Returns intervened DataFrame.
14
- This is a simplified implementation.
15
- """
16
- intervened_data = data.copy()
17
-
18
- # Direct intervention: set the value
19
- intervened_data[intervention_var] = intervention_value
20
-
21
- # Propagate effects (simplified linear model) - needs graph
22
- # For a true do-calculus, you'd prune the graph and re-estimate based on parents
23
- # For demonstration, this still uses a simplified propagation.
24
- try:
25
- # Ensure graph is connected and topological sort is possible
26
- if self.graph and not nx.is_directed_acyclic_graph(self.graph):
27
- print("Warning: Graph is not a DAG. Topological sort may fail or be incorrect for do-calculus.")
28
-
29
- # This simplified propagation is a conceptual placeholder
30
- for node in nx.topological_sort(self.graph):
31
- if node == intervention_var:
32
- continue # Do not propagate back to the intervened variable
33
-
34
- parents = list(self.graph.predecessors(node))
35
- if parents:
36
- # Very simplified linear model to show propagation
37
- # In reality, you'd use learned coefficients or structural equations
38
- combined_effect = np.zeros(len(intervened_data))
39
- for p in parents:
40
- if p in intervened_data.columns:
41
- # Use a fixed random coefficient for demonstration
42
- coeff = 0.5
43
- combined_effect += intervened_data[p].to_numpy() * coeff
44
-
45
- # Add a small random noise to simulate uncertainty
46
- intervened_data[node] += combined_effect + np.random.normal(0, 0.1, len(intervened_data))
47
- except Exception as e:
48
- print(f"Could not perform full propagation due to graph issues or simplification: {e}")
49
- # Fallback to direct intervention only if graph logic fails
50
- pass # The direct intervention `intervened_data[intervention_var] = intervention_value` is already applied
51
-
52
  return intervened_data
 
1
+ # utils/do_calculus.py
2
+ import pandas as pd
3
+ import numpy as np
4
+ import networkx as nx
5
+
6
+ class DoCalculus:
7
+ def __init__(self, graph):
8
+ self.graph = graph
9
+
10
+ def intervene(self, data, intervention_var, intervention_value):
11
+ """
12
+ Simulate do(X=x) intervention on a variable.
13
+ Returns intervened DataFrame.
14
+ This is a simplified implementation.
15
+ """
16
+ intervened_data = data.copy()
17
+
18
+ # Direct intervention: set the value
19
+ intervened_data[intervention_var] = intervention_value
20
+
21
+ # Propagate effects (simplified linear model) - needs graph
22
+ # For a true do-calculus, you'd prune the graph and re-estimate based on parents
23
+ # For demonstration, this still uses a simplified propagation.
24
+ try:
25
+ # Ensure graph is connected and topological sort is possible
26
+ if self.graph and not nx.is_directed_acyclic_graph(self.graph):
27
+ print("Warning: Graph is not a DAG. Topological sort may fail or be incorrect for do-calculus.")
28
+
29
+ # This simplified propagation is a conceptual placeholder
30
+ for node in nx.topological_sort(self.graph):
31
+ if node == intervention_var:
32
+ continue # Do not propagate back to the intervened variable
33
+
34
+ parents = list(self.graph.predecessors(node))
35
+ if parents:
36
+ # Very simplified linear model to show propagation
37
+ # In reality, you'd use learned coefficients or structural equations
38
+ combined_effect = np.zeros(len(intervened_data))
39
+ for p in parents:
40
+ if p in intervened_data.columns:
41
+ # Use a fixed random coefficient for demonstration
42
+ coeff = 0.5
43
+ combined_effect += intervened_data[p].to_numpy() * coeff
44
+
45
+ # Add a small random noise to simulate uncertainty
46
+ intervened_data[node] += combined_effect + np.random.normal(0, 0.1, len(intervened_data))
47
+ except Exception as e:
48
+ print(f"Could not perform full propagation due to graph issues or simplification: {e}")
49
+ # Fallback to direct intervention only if graph logic fails
50
+ pass # The direct intervention `intervened_data[intervention_var] = intervention_value` is already applied
51
+
52
  return intervened_data
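The propagation above multiplies every parent by a fixed 0.5 coefficient, which is purely illustrative. A hedged sketch of a slightly more principled alternative: fit one linear structural equation per node from the observational data, then resimulate the intervened system in topological order. This is still a linear-Gaussian simplification, not full do-calculus.

```python
# Sketch: linear-SCM style intervention with fitted coefficients instead of a constant 0.5.
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

def intervene_linear_scm(graph: nx.DiGraph, data: pd.DataFrame, var: str, value: float) -> pd.DataFrame:
    simulated = data.copy()
    simulated[var] = value                     # do(var = value)
    for node in nx.topological_sort(graph):    # requires the graph to be a DAG
        if node == var:
            continue                           # the intervened node keeps its forced value
        parents = [p for p in graph.predecessors(node) if p in data.columns]
        if not parents:
            continue                           # exogenous nodes keep their observed values
        model = LinearRegression().fit(data[parents], data[node])
        noise_std = float(np.std(data[node] - model.predict(data[parents])))
        simulated[node] = model.predict(simulated[parents]) + np.random.normal(0, noise_std, len(simulated))
    return simulated
```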
utils/graph_utils.py CHANGED
@@ -1,60 +1,107 @@
1
- # utils/graph_utils.py
2
- import networkx as nx
3
- import plotly.graph_objects as go
4
-
5
- def visualize_graph(graph):
6
- """
7
- Visualize a causal graph using Plotly.
8
- Returns Plotly figure as JSON.
9
- """
10
- # Use a fixed seed for layout reproducibility (optional)
11
- pos = nx.spring_layout(graph, seed=42)
12
-
13
- edge_x, edge_y = [], []
14
- for edge in graph.edges():
15
- x0, y0 = pos[edge[0]]
16
- x1, y1 = pos[edge[1]]
17
- edge_x.extend([x0, x1, None])
18
- edge_y.extend([y0, y1, None])
19
-
20
- edge_trace = go.Scatter(
21
- x=edge_x, y=edge_y,
22
- line=dict(width=1, color='#888'),
23
- mode='lines',
24
- hoverinfo='none'
25
- )
26
-
27
- node_x, node_y = [], []
28
- for node in graph.nodes():
29
- x, y = pos[node]
30
- node_x.append(x)
31
- node_y.append(y)
32
-
33
- node_trace = go.Scatter(
34
- x=node_x, y=node_y,
35
- mode='markers+text',
36
- text=list(graph.nodes()),
37
- textposition='bottom center',
38
- marker=dict(size=15, color='lightblue', line=dict(width=2, color='DarkSlateGrey')),
39
- hoverinfo='text'
40
- )
41
-
42
- fig = go.Figure(
43
- data=[edge_trace, node_trace],
44
- layout=go.Layout(
45
- showlegend=False,
46
- hovermode='closest',
47
- margin=dict(b=20, l=5, r=5, t=40),
48
- annotations=[dict(
49
- text="Python Causal Graph",
50
- showarrow=False,
51
- xref="paper", yref="paper",
52
- x=0.005, y= -0.002,
53
- font=dict(size=14, color="lightgray")
54
- )],
55
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
56
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
57
- title=dict(text="Causal Graph Visualization", font=dict(size=16)) # Corrected line
58
- )
59
- )
60
- return fig.to_json()
 
1
+ # utils/graph_utils.py
2
+ import networkx as nx
3
+ import plotly.graph_objects as go
4
+ import numpy as np
5
+
6
+ def visualize_graph(graph):
7
+ """
8
+ Visualize a causal graph using Plotly.
9
+ Returns Plotly figure as JSON.
10
+ """
11
+ # Use a fixed seed for layout reproducibility (optional)
12
+ pos = nx.spring_layout(graph, seed=42)
13
+
14
+ edge_x, edge_y = [], []
15
+ for edge in graph.edges():
16
+ x0, y0 = pos[edge[0]]
17
+ x1, y1 = pos[edge[1]]
18
+ edge_x.extend([x0, x1, None])
19
+ edge_y.extend([y0, y1, None])
20
+
21
+ edge_trace = go.Scatter(
22
+ x=edge_x, y=edge_y,
23
+ line=dict(width=1, color='#888'),
24
+ mode='lines',
25
+ hoverinfo='none'
26
+ )
27
+
28
+ node_x, node_y = [], []
29
+ for node in graph.nodes():
30
+ x, y = pos[node]
31
+ node_x.append(x)
32
+ node_y.append(y)
33
+
34
+ node_trace = go.Scatter(
35
+ x=node_x, y=node_y,
36
+ mode='markers+text',
37
+ text=list(graph.nodes()),
38
+ textposition='bottom center',
39
+ marker=dict(size=15, color='lightblue', line=dict(width=2, color='DarkSlateGrey')),
40
+ hoverinfo='text'
41
+ )
42
+
43
+ fig = go.Figure(
44
+ data=[edge_trace, node_trace],
45
+ layout=go.Layout(
46
+ showlegend=False,
47
+ hovermode='closest',
48
+ margin=dict(b=20, l=5, r=5, t=40),
49
+ annotations=[dict(
50
+ text="Python Causal Graph",
51
+ showarrow=False,
52
+ xref="paper", yref="paper",
53
+ x=0.005, y= -0.002,
54
+ font=dict(size=14, color="lightgray")
55
+ )],
56
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
57
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
58
+ title=dict(text="Causal Graph Visualization", font=dict(size=16)) # Corrected line
59
+ )
60
+ )
61
+ return fig.to_json()
62
+
63
+ def get_graph_summary_for_chatbot(graph_adj, nodes):
64
+ """
65
+ Generates a text summary of the causal graph for the chatbot.
66
+ """
67
+ if not graph_adj or not nodes:
68
+ return "No causal graph discovered yet."
69
+
70
+ adj_matrix = np.array(graph_adj)
71
+ G = nx.DiGraph(adj_matrix)
72
+
73
+ # Relabel nodes with actual names
74
+ mapping = {i: node_name for i, node_name in enumerate(nodes)}
75
+ G = nx.relabel_nodes(G, mapping)
76
+
77
+ num_nodes = G.number_of_nodes()
78
+ num_edges = G.number_of_edges()
79
+
80
+ summary = (
81
+ f"The causal graph has {num_nodes} variables (nodes) and {num_edges} causal relationships (directed edges).\n"
82
+ "The variables are: " + ", ".join(nodes) + ".\n"
83
+ )
84
+
85
+ # Add some basic structural info
86
+ if nx.is_directed_acyclic_graph(G):
87
+ summary += "The graph is a Directed Acyclic Graph (DAG), which is typical for causal models.\n"
88
+ else:
89
+ summary += "The graph contains cycles, which might indicate feedback loops or issues with the discovery algorithm for a DAG model.\n"
90
+
91
+ # Smallest graphs: list all edges
92
+ if num_edges > 0 and num_edges < 10: # Avoid listing too many edges for large graphs
93
+ edge_list = [f"{u} -> {v}" for u, v in G.edges()]
94
+ summary += "The discovered relationships are: " + ", ".join(edge_list) + ".\n"
95
+ elif num_edges >= 10:
96
+ summary += "There are many edges; you can ask for specific relationships (e.g., 'What are the direct causes of X?').\n"
97
+
98
+ # Identify source and sink nodes (if any)
99
+ source_nodes = [n for n, d in G.in_degree() if d == 0]
100
+ sink_nodes = [n for n, d in G.out_degree() if d == 0]
101
+
102
+ if source_nodes:
103
+ summary += f"Variables with no known causes (source nodes): {', '.join(source_nodes)}.\n"
104
+ if sink_nodes:
105
+ summary += f"Variables with no known effects (sink nodes): {', '.join(sink_nodes)}.\n"
106
+
107
+ return summary
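visualize_graph above draws edges as plain lines, so edge direction is not visible in the Plotly figure (the time-series section makes the same observation). One possible fix is to render each directed edge as a layout annotation arrow; the sketch below reuses the spring_layout positions, and the arrow styling values are arbitrary choices rather than project conventions.

```python
# Sketch: draw one annotation arrow per directed edge instead of an undirected line trace.
import networkx as nx
import plotly.graph_objects as go

def add_edge_arrows(fig: go.Figure, graph: nx.DiGraph, pos: dict) -> go.Figure:
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        fig.add_annotation(
            x=x1, y=y1,                      # arrow head at the effect node
            ax=x0, ay=y0,                    # arrow tail at the cause node
            xref="x", yref="y", axref="x", ayref="y",
            showarrow=True, arrowhead=3, arrowsize=1.5,
            arrowwidth=1, arrowcolor="#888",
        )
    return fig
```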
utils/prediction_models.py ADDED
@@ -0,0 +1,86 @@
 
 
1
+ # utils/prediction_models.py
2
+ import pandas as pd
3
+ from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
6
+ import numpy as np
7
+
8
+ def train_predict_random_forest(data_list, target_col, feature_cols, prediction_type='regression'):
9
+ """
10
+ Trains a Random Forest model and performs prediction/evaluation.
11
+
12
+ Args:
13
+ data_list (list of dict): List of dictionaries representing the dataset.
14
+ target_col (str): Name of the target variable.
15
+ feature_cols (list): List of names of feature variables.
16
+ prediction_type (str): 'regression' or 'classification'.
17
+
18
+ Returns:
19
+ dict: A dictionary containing model results (metrics, predictions, feature importances).
20
+ """
21
+ df = pd.DataFrame(data_list)
22
+
23
+ if not all(col in df.columns for col in feature_cols + [target_col]):
24
+ missing_cols = [col for col in feature_cols + [target_col] if col not in df.columns]
25
+ raise ValueError(f"Missing columns in data: {missing_cols}")
26
+
27
+ X = df[feature_cols]
28
+ y = df[target_col]
29
+
30
+ # Handle categorical features if any
31
+ X = pd.get_dummies(X, drop_first=True) # One-hot encode categorical features
32
+
33
+ # Split data for robust evaluation
34
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
35
+
36
+ results = {}
37
+
38
+ if prediction_type == 'regression':
39
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
40
+ model.fit(X_train, y_train)
41
+ y_pred = model.predict(X_test)
42
+
43
+ results['model_type'] = 'Regression'
44
+ results['r2_score'] = r2_score(y_test, y_pred)
45
+ results['mean_squared_error'] = mean_squared_error(y_test, y_pred)
46
+ results['root_mean_squared_error'] = np.sqrt(mean_squared_error(y_test, y_pred))
47
+ results['actual_vs_predicted'] = pd.DataFrame({'Actual': y_test, 'Predicted': y_pred}).to_dict(orient='list')
48
+
49
+ elif prediction_type == 'classification':
50
+ # Ensure target variable is suitable for classification (e.g., integer/categorical)
51
+ # You might need more robust handling for different target types here
52
+ if y.dtype == 'object' or y.dtype.name == 'category':
+ # Derive label codes from the full column so train and test splits share the same mapping
+ categories = df[target_col].astype('category').cat.categories
+ y_train = pd.Categorical(y_train, categories=categories).codes
+ y_test = pd.Categorical(y_test, categories=categories).codes
+ results['class_labels'] = categories.tolist()
57
+ else:
58
+ y_unique_labels = sorted(y.unique().tolist())
59
+ results['class_labels'] = y_unique_labels
60
+
61
+
62
+ model = RandomForestClassifier(n_estimators=100, random_state=42)
63
+ model.fit(X_train, y_train)
64
+ y_pred = model.predict(X_test)
65
+
66
+ results['model_type'] = 'Classification'
67
+ results['accuracy'] = accuracy_score(y_test, y_pred)
68
+
69
+ # Precision, Recall, F1-score - use 'weighted' average for multi-class
70
+ precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted', zero_division=0)
71
+ results['precision'] = precision
72
+ results['recall'] = recall
73
+ results['f1_score'] = f1
74
+
75
+ results['confusion_matrix'] = confusion_matrix(y_test, y_pred).tolist()
76
+ results['classification_report'] = classification_report(y_test, y_pred, output_dict=True, zero_division=0)
77
+
78
+ else:
79
+ raise ValueError("prediction_type must be 'regression' or 'classification'")
80
+
81
+ # Feature Importance (common for both)
82
+ if hasattr(model, 'feature_importances_'):
83
+ feature_importances = pd.Series(model.feature_importances_, index=X.columns).sort_values(ascending=False)
84
+ results['feature_importances'] = feature_importances.to_dict()
85
+
86
+ return results
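A hypothetical usage sketch for train_predict_random_forest (editorial, not part of the commit), run on a small synthetic regression dataset; the column names x1, x2 and y are invented for the example:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(42)
    demo = pd.DataFrame({"x1": rng.normal(size=200), "x2": rng.normal(size=200)})
    demo["y"] = 2 * demo["x1"] - demo["x2"] + rng.normal(scale=0.1, size=200)

    res = train_predict_random_forest(
        demo.to_dict(orient="records"),
        target_col="y",
        feature_cols=["x1", "x2"],
        prediction_type="regression",
    )
    print(res["r2_score"], res["root_mean_squared_error"])
    print(res["feature_importances"])  # x1 should dominate given the generating equation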
utils/preprocessor.py CHANGED
@@ -1,57 +1,88 @@
1
- # utils/preprocessor.py
2
- from sklearn.preprocessing import StandardScaler, LabelEncoder
3
- import pandas as pd
4
- import numpy as np
5
- import logging
6
-
7
- # Set up logging
8
- logging.basicConfig(level=logging.INFO)
9
- logger = logging.getLogger(__name__)
10
-
11
- class DataPreprocessor:
12
- def __init__(self):
13
- self.scaler = StandardScaler()
14
- self.label_encoders = {}
15
-
16
- def preprocess(self, df):
17
- """
18
- Preprocess DataFrame: handle missing values, encode categorical variables, scale numerical variables.
19
- """
20
- try:
21
- logger.info(f"Input DataFrame shape: {df.shape}, columns: {list(df.columns)}")
22
- df_processed = df.copy()
23
-
24
- # Handle missing values
25
- logger.info("Handling missing values...")
26
- for col in df_processed.columns:
27
- if df_processed[col].isnull().any():
28
- if pd.api.types.is_numeric_dtype(df_processed[col]):
29
- df_processed[col] = df_processed[col].fillna(df_processed[col].mean())
30
- logger.info(f"Filled numeric missing values in '{col}' with mean.")
31
- else:
32
- df_processed[col] = df_processed[col].fillna(df_processed[col].mode()[0])
33
- logger.info(f"Filled categorical missing values in '{col}' with mode.")
34
-
35
- # Encode categorical variables
36
- logger.info("Encoding categorical variables...")
37
- for col in df_processed.select_dtypes(include=['object', 'category']).columns:
38
- logger.info(f"Encoding column: {col}")
39
- self.label_encoders[col] = LabelEncoder()
40
- df_processed[col] = self.label_encoders[col].fit_transform(df_processed[col])
41
-
42
- # Scale numerical variables
43
- logger.info("Scaling numerical variables...")
44
- numeric_cols = df_processed.select_dtypes(include=[np.number]).columns
45
- if len(numeric_cols) > 0:
46
- # Exclude columns that are now effectively categorical (post-label encoding)
47
- # This is a heuristic; ideally, identify original numeric columns.
48
- cols_to_scale = [col for col in numeric_cols if col not in self.label_encoders]
49
- if cols_to_scale:
50
- df_processed[cols_to_scale] = self.scaler.fit_transform(df_processed[cols_to_scale])
51
- logger.info(f"Scaled numeric columns: {cols_to_scale}")
52
-
53
- logger.info(f"Preprocessed DataFrame shape: {df_processed.shape}")
54
- return df_processed
55
- except Exception as e:
56
- logger.exception(f"Error preprocessing data: {str(e)}")
57
- raise
1
+ # utils/preprocessor.py
2
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
3
+ import pandas as pd
4
+ import numpy as np
5
+ import logging
6
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
7
+ from sklearn.impute import SimpleImputer
8
+ from sklearn.compose import ColumnTransformer
9
+ from sklearn.pipeline import Pipeline
10
+
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class DataPreprocessor:
16
+ def __init__(self):
17
+ self.scaler = StandardScaler()
18
+ self.label_encoders = {}
19
+
20
+ def preprocess(self, df):
21
+ """
22
+ Preprocess DataFrame: handle missing values, encode categorical variables, scale numerical variables.
23
+ """
24
+ try:
25
+ logger.info(f"Input DataFrame shape: {df.shape}, columns: {list(df.columns)}")
26
+ df_processed = df.copy()
27
+
28
+ # Handle missing values
29
+ logger.info("Handling missing values...")
30
+ for col in df_processed.columns:
31
+ if df_processed[col].isnull().any():
32
+ if pd.api.types.is_numeric_dtype(df_processed[col]):
33
+ df_processed[col] = df_processed[col].fillna(df_processed[col].mean())
34
+ logger.info(f"Filled numeric missing values in '{col}' with mean.")
35
+ else:
36
+ df_processed[col] = df_processed[col].fillna(df_processed[col].mode()[0])
37
+ logger.info(f"Filled categorical missing values in '{col}' with mode.")
38
+
39
+ # Encode categorical variables
40
+ logger.info("Encoding categorical variables...")
41
+ for col in df_processed.select_dtypes(include=['object', 'category']).columns:
42
+ logger.info(f"Encoding column: {col}")
43
+ self.label_encoders[col] = LabelEncoder()
44
+ df_processed[col] = self.label_encoders[col].fit_transform(df_processed[col])
45
+
46
+ # Scale numerical variables
47
+ logger.info("Scaling numerical variables...")
48
+ numeric_cols = df_processed.select_dtypes(include=[np.number]).columns
49
+ if len(numeric_cols) > 0:
50
+ # Exclude columns that are now effectively categorical (post-label encoding)
51
+ # This is a heuristic; ideally, identify original numeric columns.
52
+ cols_to_scale = [col for col in numeric_cols if col not in self.label_encoders]
53
+ if cols_to_scale:
54
+ df_processed[cols_to_scale] = self.scaler.fit_transform(df_processed[cols_to_scale])
55
+ logger.info(f"Scaled numeric columns: {cols_to_scale}")
56
+
57
+ logger.info(f"Preprocessed DataFrame shape: {df_processed.shape}")
58
+ return df_processed
59
+ except Exception as e:
60
+ logger.exception(f"Error preprocessing data: {str(e)}")
61
+ raise
62
+
63
+ def summarize_dataframe_for_chatbot(data_list):
64
+ """
65
+ Generates a text summary of the DataFrame for chatbot interaction."""
66
+ if not data_list:
67
+ return "No data loaded."
68
+ df = pd.DataFrame(data_list)
69
+ num_rows, num_cols = df.shape
70
+
71
+ col_info = []
72
+ for col in df.columns:
73
+ dtype = df[col].dtype
74
+ unique_vals = df[col].nunique()
75
+ missing_count = df[col].isnull().sum()
76
+
77
+ info = f"- {col} (Type: {dtype}"
78
+ if pd.api.types.is_numeric_dtype(df[col]):
79
+ info += f", Min: {df[col].min():.2f}, Max: {df[col].max():.2f}"
80
+ else:
81
+ info += f", Unique: {unique_vals}"
82
+
83
+ if missing_count > 0:
84
+ info += f", Missing: {missing_count}"
85
+ info += ")"
86
+ col_info.append(info)
87
+ summary = (f"Dataset Summary:\n- Rows: {num_rows}, Columns: {num_cols}\nColumns:\n" + "\n".join(col_info))
88
+ return summary
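A short usage sketch for the two helpers above (editorial, not part of the commit); the records below are invented for illustration:

    import pandas as pd

    records = [
        {"age": 34, "income": 55000.0, "segment": "A"},
        {"age": 29, "income": None, "segment": "B"},
        {"age": 41, "income": 72000.0, "segment": "A"},
    ]
    # Text summary for the chatbot: 3 rows, 3 columns, min/max for numeric
    # columns, and one missing value flagged for 'income'.
    print(summarize_dataframe_for_chatbot(records))

    # Imputation, label encoding and scaling via the DataPreprocessor class
    processed = DataPreprocessor().preprocess(pd.DataFrame(records))
    print(processed.head())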
utils/time_series_causal.py ADDED
@@ -0,0 +1,102 @@
1
+ # utils/time_series_causal.py
2
+ import pandas as pd
3
+ from statsmodels.tsa.stattools import grangercausalitytests
4
+
5
+ def perform_granger_causality(data_list, timestamp_col, variables_to_analyze, max_lags=1):
6
+ """
7
+ Performs pairwise Granger Causality tests on the given time-series data.
8
+
9
+ Args:
10
+ data_list (list of dict): List of dictionaries representing the dataset.
11
+ timestamp_col (str): Name of the timestamp column.
12
+ variables_to_analyze (list): List of names of variables to test for causality.
13
+ max_lags (int): The maximum number of lags to use for the Granger causality test.
14
+
15
+ Returns:
16
+ list: A list of dictionaries, each describing a causal relationship found.
17
+ """
18
+ df = pd.DataFrame(data_list)
19
+
20
+ if timestamp_col not in df.columns:
21
+ raise ValueError(f"Timestamp column '{timestamp_col}' not found in data.")
22
+
23
+ # Ensure timestamp column is datetime and set as index
24
+ try:
25
+ df[timestamp_col] = pd.to_datetime(df[timestamp_col])
26
+ df = df.set_index(timestamp_col).sort_index()
27
+ except Exception as e:
28
+ raise ValueError(f"Could not convert timestamp column '{timestamp_col}' to datetime: {e}")
29
+
30
+ # Ensure all variables to analyze are numeric
31
+ for col in variables_to_analyze:
32
+ if not pd.api.types.is_numeric_dtype(df[col]):
33
+ raise ValueError(f"Variable '{col}' is not numeric. Granger Causality requires numeric variables.")
34
+ if df[col].isnull().any():
35
+ # Handle NaNs: Granger Causality tests require no NaN values.
36
+ # You might choose to drop rows with NaNs or impute.
37
+ # For simplicity, here we'll raise an error or drop them.
38
+ # print(f"Warning: Variable '{col}' contains NaN values. Rows with NaNs will be dropped.")
39
+ df = df.dropna(subset=[col])
40
+
41
+
42
+ # Select only the relevant columns
43
+ df_selected = df[variables_to_analyze]
44
+
45
+ # Granger Causality requires stationarity in theory.
46
+ # While statsmodels can run on non-stationary data, results should be interpreted cautiously.
47
+ # You might want to add differencing logic here (e.g., df.diff().dropna())
48
+ # or a warning for the user.
49
+ # For now, we proceed directly.
50
+
51
+ causal_results = []
52
+
53
+ # Iterate through all unique pairs of variables
54
+ for i in range(len(variables_to_analyze)):
55
+ for j in range(len(variables_to_analyze)):
56
+ if i == j:
57
+ continue # Skip self-causation tests
58
+
59
+ cause_var = variables_to_analyze[i]
60
+ effect_var = variables_to_analyze[j]
61
+
62
+ # Prepare data for grangercausalitytests: [effect_var, cause_var]
63
+ # grangercausalitytests takes a DataFrame where the first column is the dependent variable (effect)
64
+ # and the second column is the independent variable (cause)
65
+ data_for_test = df_selected[[effect_var, cause_var]]
66
+
67
+ if data_for_test.empty or len(data_for_test) <= max_lags:
68
+ # Not enough data points to perform test with specified lags
69
+ # This can happen if NaNs were dropped or dataset is too small
70
+ continue
71
+
72
+ try:
73
+ # Perform Granger Causality test
74
+ # The output is a dictionary. The key 'ssr_ftest' (or 'params_ftest')
75
+ # usually contains the p-value.
76
+ test_result = grangercausalitytests(data_for_test, max_lags, verbose=False)
77
+
78
+ # Extract p-value for the optimal lag or the test that interests you
79
+ # Commonly, F-test p-value for the last lag tested is used
80
+ # test_result is a dictionary where keys are lag numbers
81
+ # Each lag has a tuple of (test_statistics, p_values).
82
+ # (F-test, Chi2-test, LR-test, SSR-test) -> [statistic, p-value, df_denom, df_num]
83
+
84
+ # Let's consider the F-test for the last lag as a general indicator
85
+ last_lag_p_value = test_result[max_lags][0]['ssr_ftest'][1] # F-test p-value
86
+
87
+ causal_results.append({
88
+ "cause": cause_var,
89
+ "effect": effect_var,
90
+ "p_value": last_lag_p_value,
91
+ "test_type": "Granger Causality (F-test)",
92
+ "max_lags": max_lags
93
+ })
94
+ except ValueError as ve:
95
+ # Handle cases where the test cannot be performed (e.g., singular matrix)
96
+ print(f"Could not perform Granger Causality for {cause_var} -> {effect_var} with max_lags={max_lags}: {ve}")
97
+ continue # Skip this pair
98
+ except Exception as e:
99
+ print(f"An unexpected error occurred for {cause_var} -> {effect_var}: {e}")
100
+ continue
101
+
102
+ return causal_results
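As the stationarity comment above notes, the test is usually applied to (approximately) stationary series. A hedged sketch of how a caller might first-difference the data before running the test (editorial, not part of the commit; data_list, the 'date' column and the value_a/value_b variables are assumed for illustration):

    import pandas as pd

    df = pd.DataFrame(data_list)            # same list-of-dict records used elsewhere
    df["date"] = pd.to_datetime(df["date"])
    df = df.set_index("date").sort_index()
    df_diff = df[["value_a", "value_b"]].diff().dropna().reset_index()

    results = perform_granger_causality(
        df_diff.to_dict(orient="records"),
        timestamp_col="date",
        variables_to_analyze=["value_a", "value_b"],
        max_lags=2,
    )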
utils/treatment_effects.py CHANGED
@@ -1,63 +1,63 @@
1
- # utils/treatment_effects.py
2
- from sklearn.linear_model import LinearRegression, LogisticRegression
3
- import pandas as pd
4
- import numpy as np
5
- # For matching-based methods, you might need libraries like dowhy or causalml
6
- # import statsmodels.api as sm # Example for regression diagnostics
7
-
8
- class TreatmentEffectAlgorithms:
9
- def linear_regression_ate(self, df, treatment_col, outcome_col, covariates):
10
- """
11
- Estimate ATE using linear regression.
12
- """
13
- X = df[covariates + [treatment_col]]
14
- y = df[outcome_col]
15
- model = LinearRegression()
16
- model.fit(X, y)
17
- ate = model.coef_[-1] # Coefficient of treatment_col
18
- return float(ate)
19
-
20
- def propensity_score_matching(self, df, treatment_col, outcome_col, covariates):
21
- """
22
- Placeholder for Propensity Score Matching.
23
- You would implement or integrate a matching algorithm here.
24
- """
25
- print("Propensity Score Matching is a placeholder. Returning a dummy ATE.")
26
- # Simplified: Estimate propensity scores
27
- X_propensity = df[covariates]
28
- T_propensity = df[treatment_col]
29
- prop_model = LogisticRegression(solver='liblinear')
30
- prop_model.fit(X_propensity, T_propensity)
31
- propensity_scores = prop_model.predict_proba(X_propensity)[:, 1]
32
-
33
- # Dummy ATE calculation for demonstration
34
- treated_outcome = df[df[treatment_col] == 1][outcome_col].mean()
35
- control_outcome = df[df[treatment_col] == 0][outcome_col].mean()
36
- return float(treated_outcome - control_outcome) # Simplified dummy ATE
37
-
38
- def inverse_propensity_weighting(self, df, treatment_col, outcome_col, covariates):
39
- """
40
- Placeholder for Inverse Propensity Weighting (IPW).
41
- You would implement or integrate IPW here.
42
- """
43
- print("Inverse Propensity Weighting is a placeholder. Returning a dummy ATE.")
44
- # Dummy ATE for demonstration
45
- return np.random.rand() * 10 # Random dummy value
46
-
47
- def t_learner(self, df, treatment_col, outcome_col, covariates):
48
- """
49
- Placeholder for T-learner.
50
- You would implement a T-learner using two separate models.
51
- """
52
- print("T-learner is a placeholder. Returning a dummy ATE.")
53
- # Dummy ATE for demonstration
54
- return np.random.rand() * 10 + 5 # Random dummy value
55
-
56
- def s_learner(self, df, treatment_col, outcome_col, covariates):
57
- """
58
- Placeholder for S-learner.
59
- You would implement an S-learner using a single model.
60
- """
61
- print("S-learner is a placeholder. Returning a dummy ATE.")
62
- # Dummy ATE for demonstration
63
  return np.random.rand() * 10 - 2 # Random dummy value
 
1
+ # utils/treatment_effects.py
2
+ from sklearn.linear_model import LinearRegression, LogisticRegression
3
+ import pandas as pd
4
+ import numpy as np
5
+ # For matching-based methods, you might need libraries like dowhy or causalml
6
+ # import statsmodels.api as sm # Example for regression diagnostics
7
+
8
+ class TreatmentEffectAlgorithms:
9
+ def linear_regression_ate(self, df, treatment_col, outcome_col, covariates):
10
+ """
11
+ Estimate ATE using linear regression.
12
+ """
13
+ X = df[covariates + [treatment_col]]
14
+ y = df[outcome_col]
15
+ model = LinearRegression()
16
+ model.fit(X, y)
17
+ ate = model.coef_[-1] # Coefficient of treatment_col
18
+ return float(ate)
19
+
20
+ def propensity_score_matching(self, df, treatment_col, outcome_col, covariates):
21
+ """
22
+ Placeholder for Propensity Score Matching.
23
+ You would implement or integrate a matching algorithm here.
24
+ """
25
+ print("Propensity Score Matching is a placeholder. Returning a dummy ATE.")
26
+ # Simplified: Estimate propensity scores
27
+ X_propensity = df[covariates]
28
+ T_propensity = df[treatment_col]
29
+ prop_model = LogisticRegression(solver='liblinear')
30
+ prop_model.fit(X_propensity, T_propensity)
31
+ propensity_scores = prop_model.predict_proba(X_propensity)[:, 1]
32
+
33
+ # Dummy ATE calculation for demonstration
34
+ treated_outcome = df[df[treatment_col] == 1][outcome_col].mean()
35
+ control_outcome = df[df[treatment_col] == 0][outcome_col].mean()
36
+ return float(treated_outcome - control_outcome) # Simplified dummy ATE
37
+
38
+ def inverse_propensity_weighting(self, df, treatment_col, outcome_col, covariates):
39
+ """
40
+ Placeholder for Inverse Propensity Weighting (IPW).
41
+ You would implement or integrate IPW here.
42
+ """
43
+ print("Inverse Propensity Weighting is a placeholder. Returning a dummy ATE.")
44
+ # Dummy ATE for demonstration
45
+ return np.random.rand() * 10 # Random dummy value
46
+
47
+ def t_learner(self, df, treatment_col, outcome_col, covariates):
48
+ """
49
+ Placeholder for T-learner.
50
+ You would implement a T-learner using two separate models.
51
+ """
52
+ print("T-learner is a placeholder. Returning a dummy ATE.")
53
+ # Dummy ATE for demonstration
54
+ return np.random.rand() * 10 + 5 # Random dummy value
55
+
56
+ def s_learner(self, df, treatment_col, outcome_col, covariates):
57
+ """
58
+ Placeholder for S-learner.
59
+ You would implement an S-learner using a single model.
60
+ """
61
+ print("S-learner is a placeholder. Returning a dummy ATE.")
62
+ # Dummy ATE for demonstration
63
  return np.random.rand() * 10 - 2 # Random dummy value
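For reference, hedged sketches of what the inverse_propensity_weighting and t_learner placeholders could compute, assuming a binary 0/1 treatment column and numeric covariates (editorial illustrations, not part of the commit):

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.linear_model import LogisticRegression

    def ipw_ate_sketch(df, treatment_col, outcome_col, covariates):
        # Propensity model e(x) = P(T = 1 | X), clipped to avoid extreme weights
        ps = LogisticRegression(solver="liblinear").fit(df[covariates], df[treatment_col])
        e = ps.predict_proba(df[covariates])[:, 1].clip(0.01, 0.99)
        t = df[treatment_col].to_numpy()
        y = df[outcome_col].to_numpy()
        # Horvitz-Thompson style ATE: E[T*Y/e(X)] - E[(1-T)*Y/(1-e(X))]
        return float(np.mean(t * y / e) - np.mean((1 - t) * y / (1 - e)))

    def t_learner_ate_sketch(df, treatment_col, outcome_col, covariates):
        # Fit separate outcome models on treated and control rows, then average
        # the difference of their predictions over the whole sample.
        treated = df[df[treatment_col] == 1]
        control = df[df[treatment_col] == 0]
        m1 = RandomForestRegressor(n_estimators=100, random_state=42).fit(treated[covariates], treated[outcome_col])
        m0 = RandomForestRegressor(n_estimators=100, random_state=42).fit(control[covariates], control[outcome_col])
        return float(np.mean(m1.predict(df[covariates]) - m0.predict(df[covariates])))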