ShutterStack committed on
Commit
f6aa1e3
·
verified ·
1 Parent(s): 2e39af3

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +633 -618
streamlit_app.py CHANGED
@@ -1,619 +1,634 @@
1
- # streamlit_app.py
2
- import streamlit as st
3
- import pandas as pd
4
- import requests
5
- import json
6
- import plotly.express as px
7
- import plotly.graph_objects as go
8
- import numpy as np # For random array in placeholders
9
- import os
10
-
11
- # Configuration
12
- FLASK_API_URL = "http://localhost:5000" # Ensure this matches your Flask app's host and port
13
-
14
- st.set_page_config(layout="wide", page_title="CausalBox Toolkit")
15
-
16
- st.title("🔬 CausalBox: A Causal Inference Toolkit")
17
- st.markdown("Uncover causal relationships, simulate interventions, and estimate treatment effects.")
18
-
19
- # --- Session State Initialization ---
20
- if 'processed_data' not in st.session_state:
21
- st.session_state.processed_data = None
22
- if 'processed_columns' not in st.session_state:
23
- st.session_state.processed_columns = None
24
- if 'causal_graph_adj' not in st.session_state:
25
- st.session_state.causal_graph_adj = None
26
- if 'causal_graph_nodes' not in st.session_state:
27
- st.session_state.causal_graph_nodes = None
28
-
29
- # --- Data Preprocessing Module ---
30
- st.header("1. Data Preprocessor 🧹")
31
- st.write("Upload your CSV dataset or use a generated sample dataset.")
32
-
33
- # Option to use generated sample dataset
34
- if st.button("Use Sample Dataset (sample_dataset.csv)"):
35
- # In a real scenario, Streamlit would serve the file or you'd load it directly if local.
36
- # For this setup, we assume the Flask backend can access it or you manually upload it once.
37
- # For demonstration, we'll simulate loading a generic DataFrame.
38
- # In a full deployment, you'd have a mechanism to either:
39
- # a) Have Flask serve the sample file, or
40
- # b) Directly load it in Streamlit if the app and data are co-located.
41
- try:
42
- # Assuming the sample dataset is accessible or you are testing locally with `scripts/generate_data.py`
43
- # and then manually uploading this generated file.
44
- # For simplicity, we'll create a dummy df here if not actually uploaded.
45
- sample_df_path = "data/sample_dataset.csv" # Path relative to main.py or Streamlit app execution
46
- if os.path.exists(sample_df_path):
47
- sample_df = pd.read_csv(sample_df_path)
48
- st.success(f"Loaded sample dataset from {sample_df_path}. Please upload this file if running from different directory.")
49
- else:
50
- st.warning("Sample dataset not found at data/sample_dataset.csv.")
51
- # Dummy DataFrame for demonstration if sample file isn't found
52
- sample_df = pd.DataFrame(np.random.rand(10, 5), columns=[f'col_{i}' for i in range(5)])
53
-
54
- # Convert to JSON for Flask API call
55
- files = {'file': ('sample_dataset.csv', sample_df.to_csv(index=False), 'text/csv')}
56
- response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
57
-
58
- if response.status_code == 200:
59
- result = response.json()
60
- st.session_state.processed_data = result['data']
61
- st.session_state.processed_columns = result['columns']
62
- st.success("Sample dataset preprocessed successfully!")
63
- st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
64
- else:
65
- st.error(f"Error preprocessing sample dataset: {response.json().get('detail', 'Unknown error')}")
66
- except Exception as e:
67
- st.error(f"Could not load or process sample dataset: {e}")
68
-
69
-
70
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
71
- if uploaded_file is not None:
72
- st.info("Uploading and preprocessing data...")
73
- files = {'file': (uploaded_file.name, uploaded_file.getvalue(), 'text/csv')}
74
- try:
75
- response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
76
- if response.status_code == 200:
77
- result = response.json()
78
- st.session_state.processed_data = result['data']
79
- st.session_state.processed_columns = result['columns']
80
- st.success("File preprocessed successfully!")
81
- st.dataframe(pd.DataFrame(st.session_state.processed_data).head()) # Display first few rows
82
- else:
83
- st.error(f"Error during preprocessing: {response.json().get('detail', 'Unknown error')}")
84
- except requests.exceptions.ConnectionError:
85
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
86
- except Exception as e:
87
- st.error(f"An unexpected error occurred: {e}")
88
-
89
- # --- Causal Discovery Module ---
90
- st.header("2. Causal Discovery 🕵️‍♂️")
91
- if st.session_state.processed_data:
92
- st.write("Learn the causal structure from your preprocessed data.")
93
-
94
- discovery_algo = st.selectbox(
95
- "Select Causal Discovery Algorithm:",
96
- ("PC Algorithm", "GES (Greedy Equivalence Search) - Placeholder", "NOTEARS - Placeholder")
97
- )
98
-
99
- if st.button("Discover Causal Graph"):
100
- st.info(f"Discovering graph using {discovery_algo}...")
101
- algo_map = {
102
- "PC Algorithm": "pc",
103
- "GES (Greedy Equivalence Search) - Placeholder": "ges",
104
- "NOTEARS - Placeholder": "notears"
105
- }
106
- selected_algo_code = algo_map[discovery_algo]
107
-
108
- try:
109
- response = requests.post(
110
- f"{FLASK_API_URL}/discover/",
111
- json={"data": st.session_state.processed_data, "algorithm": selected_algo_code}
112
- )
113
- if response.status_code == 200:
114
- result = response.json()
115
- st.session_state.causal_graph_adj = result['graph']
116
- st.session_state.causal_graph_nodes = st.session_state.processed_columns
117
- st.success("Causal graph discovered!")
118
- st.subheader("Causal Graph Visualization")
119
- # Visualization will be handled by the Causal Graph Visualizer section
120
- else:
121
- st.error(f"Error during causal discovery: {response.json().get('detail', 'Unknown error')}")
122
- except requests.exceptions.ConnectionError:
123
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
124
- except Exception as e:
125
- st.error(f"An unexpected error occurred: {e}")
126
- else:
127
- st.info("Please preprocess data first to enable causal discovery.")
128
-
129
- # --- Causal Graph Visualizer Module ---
130
- st.header("3. Causal Graph Visualizer 📊")
131
- if st.session_state.causal_graph_adj and st.session_state.causal_graph_nodes:
132
- st.write("Interactive visualization of the discovered causal graph.")
133
- try:
134
- response = requests.post(
135
- f"{FLASK_API_URL}/visualize/graph",
136
- json={"graph": st.session_state.causal_graph_adj, "nodes": st.session_state.causal_graph_nodes}
137
- )
138
- if response.status_code == 200:
139
- graph_json = response.json()['graph']
140
- fig = go.Figure(json.loads(graph_json))
141
- st.plotly_chart(fig, use_container_width=True)
142
- st.markdown("""
143
- **Graph Explanation:**
144
- * **Nodes:** Represent variables in your dataset.
145
- * **Arrows (Edges):** Indicate a direct causal influence from one variable (the tail) to another (the head).
146
- * **No Arrow:** Suggests no direct causal relationship was found, or the relationship is mediated by other variables.
147
-
148
- This graph helps answer "Why did it happen?" by showing the structural relationships.
149
- """)
150
- else:
151
- st.error(f"Error visualizing graph: {response.json().get('detail', 'Unknown error')}")
152
- except requests.exceptions.ConnectionError:
153
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
154
- except Exception as e:
155
- st.error(f"An unexpected error occurred during visualization: {e}")
156
- else:
157
- st.info("Please discover a causal graph first to visualize it.")
158
-
159
-
160
- # --- Do-Calculus Engine Module ---
161
- st.header("4. Do-Calculus Engine 🧪")
162
- if st.session_state.processed_data and st.session_state.causal_graph_adj:
163
- st.write("Simulate interventions and observe their effects based on the causal graph.")
164
-
165
- intervention_var = st.selectbox(
166
- "Select variable to intervene on:",
167
- st.session_state.processed_columns,
168
- key="inter_var_select"
169
- )
170
- # Attempt to infer type for intervention_value input
171
- # Simplified approach: assuming numerical for now due to preprocessor output
172
- if intervention_var and isinstance(st.session_state.processed_data[0][intervention_var], (int, float)):
173
- intervention_value = st.number_input(f"Set '{intervention_var}' to value:", key="inter_val_input")
174
- else: # Treat as string/categorical for input, then try to preprocess for API
175
- intervention_value = st.text_input(f"Set '{intervention_var}' to value:", key="inter_val_input_text")
176
- st.warning("Categorical intervention values might require specific encoding logic on the backend.")
177
-
178
- if st.button("Perform Intervention"):
179
- st.info(f"Performing intervention: do('{intervention_var}' = {intervention_value})...")
180
- try:
181
- response = requests.post(
182
- f"{FLASK_API_URL}/intervene/",
183
- json={
184
- "data": st.session_state.processed_data,
185
- "intervention_var": intervention_var,
186
- "intervention_value": intervention_value,
187
- "graph": st.session_state.causal_graph_adj # Pass graph for advanced do-calculus
188
- }
189
- )
190
- if response.status_code == 200:
191
- intervened_data = pd.DataFrame(response.json()['intervened_data'])
192
- st.success("Intervention simulated successfully!")
193
- st.subheader("Intervened Data (First 10 rows)")
194
- st.dataframe(intervened_data.head(10))
195
-
196
- # Simple comparison visualization (e.g., histogram of outcome variable)
197
- if st.session_state.processed_columns and 'FinalExamScore' in st.session_state.processed_columns:
198
- original_df = pd.DataFrame(st.session_state.processed_data)
199
- fig_dist = go.Figure()
200
- fig_dist.add_trace(go.Histogram(x=original_df['FinalExamScore'], name='Original', opacity=0.7))
201
- fig_dist.add_trace(go.Histogram(x=intervened_data['FinalExamScore'], name='Intervened', opacity=0.0))
202
-
203
- st.plotly_chart(fig_dist, use_container_width=True)
204
- st.markdown("""
205
- **Intervention Explanation:**
206
- * By simulating `do(X=x)`, we are forcing the value of X, effectively breaking its causal links from its parents.
207
- * The graph above shows the distribution of a key outcome variable (e.g., `FinalExamScore`) before and after the intervention.
208
- * This helps answer "What if we do this instead?" by showing the predicted outcome.
209
- """)
210
- else:
211
- st.info("Consider adding a relevant outcome variable to your dataset for better intervention analysis.")
212
- else:
213
- st.error(f"Error during intervention: {response.json().get('detail', 'Unknown error')}")
214
- except requests.exceptions.ConnectionError:
215
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
216
- except Exception as e:
217
- st.error(f"An unexpected error occurred during intervention: {e}")
218
- else:
219
- st.info("Please preprocess data and discover a causal graph first to perform interventions.")
220
-
221
- # --- Treatment Effect Estimator Module ---
222
- st.header("5. Treatment Effect Estimator 🎯")
223
- if st.session_state.processed_data:
224
- st.write("Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).")
225
-
226
- col1, col2 = st.columns(2)
227
- with col1:
228
- treatment_col = st.selectbox(
229
- "Select Treatment Variable:",
230
- st.session_state.processed_columns,
231
- key="treat_col_select"
232
- )
233
- with col2:
234
- outcome_col = st.selectbox(
235
- "Select Outcome Variable:",
236
- st.session_state.processed_columns,
237
- key="outcome_col_select"
238
- )
239
-
240
- all_cols_except_treat_outcome = [col for col in st.session_state.processed_columns if col not in [treatment_col, outcome_col]]
241
- covariates = st.multiselect(
242
- "Select Covariates (confounders):",
243
- all_cols_except_treat_outcome,
244
- default=all_cols_except_treat_outcome, # Default to all other columns
245
- key="covariates_select"
246
- )
247
-
248
- estimation_method = st.selectbox(
249
- "Select Estimation Method:",
250
- (
251
- "Linear Regression ATE",
252
- "Propensity Score Matching - Placeholder",
253
- "Inverse Propensity Weighting - Placeholder",
254
- "T-learner - Placeholder",
255
- "S-learner - Placeholder"
256
- )
257
- )
258
-
259
- if st.button("Estimate Treatment Effect"):
260
- st.info(f"Estimating treatment effect using {estimation_method}...")
261
- method_map = {
262
- "Linear Regression ATE": "linear_regression",
263
- "Propensity Score Matching - Placeholder": "propensity_score_matching",
264
- "Inverse Propensity Weighting - Placeholder": "inverse_propensity_weighting",
265
- "T-learner - Placeholder": "t_learner",
266
- "S-learner - Placeholder": "s_learner"
267
- }
268
- selected_method_code = method_map[estimation_method]
269
-
270
- try:
271
- response = requests.post(
272
- f"{FLASK_API_URL}/treatment/estimate_ate",
273
- json={
274
- "data": st.session_state.processed_data,
275
- "treatment_col": treatment_col,
276
- "outcome_col": outcome_col,
277
- "covariates": covariates,
278
- "method": selected_method_code
279
- }
280
- )
281
- if response.status_code == 200:
282
- ate_result = response.json()['result']
283
- st.success(f"Treatment effect estimated using {estimation_method}:")
284
- st.write(f"**Estimated ATE: {ate_result:.4f}**")
285
- st.markdown("""
286
- **Treatment Effect Explanation:**
287
- * **Average Treatment Effect (ATE):** Measures the average causal effect of a treatment (e.g., `StudyHours`) on an outcome (e.g., `FinalExamScore`) across the entire population.
288
- * It answers "How much does doing X cause a change in Y?".
289
- * This estimation attempts to control for confounders (variables that influence both treatment and outcome) to isolate the true causal effect.
290
- """)
291
- else:
292
- st.error(f"Error during ATE estimation: {response.json().get('detail', 'Unknown error')}")
293
- except requests.exceptions.ConnectionError:
294
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
295
- except Exception as e:
296
- st.error(f"An unexpected error occurred during ATE estimation: {e}")
297
- else:
298
- st.info("Please preprocess data first to estimate treatment effects.")
299
-
300
- # --- Prediction Module ---
301
- st.header("6. Prediction Module 📈")
302
- if st.session_state.processed_data:
303
- st.write("Train a machine learning model for prediction (Regression or Classification).")
304
-
305
- prediction_type = st.selectbox(
306
- "Select Prediction Type:",
307
- ("Regression", "Classification"),
308
- key="prediction_type_select"
309
- )
310
-
311
- all_columns = st.session_state.processed_columns
312
-
313
- suitable_target_columns = []
314
- if st.session_state.processed_data:
315
- temp_df = pd.DataFrame(st.session_state.processed_data)
316
- for col in all_columns:
317
- # For classification, check if column is object type (string), boolean,
318
- # or has a limited number of unique integer values (e.g., less than 20 unique values)
319
- if prediction_type == 'Classification':
320
- if temp_df[col].dtype == 'object' or temp_df[col].dtype == 'bool':
321
- suitable_target_columns.append(col)
322
- elif pd.api.types.is_integer_dtype(temp_df[col]) and temp_df[col].nunique() < 20: # Heuristic for discrete integers
323
- suitable_target_columns.append(col)
324
- # For regression, primarily numerical columns
325
- elif prediction_type == 'Regression':
326
- if pd.api.types.is_numeric_dtype(temp_df[col]):
327
- suitable_target_columns.append(col)
328
-
329
- if not suitable_target_columns:
330
- st.warning(f"No suitable target columns found for {prediction_type}. Please check your data types.")
331
- target_col = None # Set to None to prevent error if no columns are found
332
- else:
333
- # Try to pre-select the currently chosen target_col if it's still suitable
334
- # Otherwise, default to the first suitable column
335
- if 'target_col_select' in st.session_state and st.session_state.target_col_select in suitable_target_columns:
336
- default_target_index = suitable_target_columns.index(st.session_state.target_col_select)
337
- else:
338
- default_target_index = 0
339
-
340
- target_col = st.selectbox(
341
- "Select Target Variable:",
342
- suitable_target_columns,
343
- index=default_target_index,
344
- key="target_col_select"
345
- )
346
-
347
- # Filter out the target column from feature options
348
- feature_options = [col for col in all_columns if col != target_col]
349
- feature_cols = st.multiselect(
350
- "Select Feature Variables:",
351
- feature_options,
352
- default=feature_options, # Default to all other columns
353
- key="feature_cols_select"
354
- )
355
-
356
- if st.button("Train Model & Predict", key="train_predict_button"):
357
- if not target_col or not feature_cols:
358
- st.warning("Please select a target variable and at least one feature variable.")
359
- else:
360
- st.info(f"Training {prediction_type} model using Random Forest...")
361
- try:
362
- response = requests.post(
363
- f"{FLASK_API_URL}/prediction/train_predict",
364
- json={
365
- "data": st.session_state.processed_data,
366
- "target_col": target_col,
367
- "feature_cols": feature_cols,
368
- "prediction_type": prediction_type.lower()
369
- }
370
- )
371
-
372
- if response.status_code == 200:
373
- results = response.json()['results']
374
- st.success(f"{prediction_type} Model Trained Successfully!")
375
- st.subheader("Model Performance")
376
-
377
- if prediction_type == 'Regression':
378
- st.write(f"**R-squared:** {results['r2_score']:.4f}")
379
- st.write(f"**Mean Squared Error (MSE):** {results['mean_squared_error']:.4f}")
380
- st.write(f"**Root Mean Squared Error (RMSE):** {results['root_mean_squared_error']:.4f}")
381
-
382
- st.subheader("Actual vs. Predicted Plot")
383
- actual_predicted_df = pd.DataFrame(results['actual_vs_predicted'])
384
- fig_reg = px.scatter(actual_predicted_df, x='Actual', y='Predicted',
385
- title='Actual vs. Predicted Values',
386
- labels={'Actual': f'Actual {target_col}', 'Predicted': f'Predicted {target_col}'})
387
- fig_reg.add_trace(go.Scatter(x=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
388
- y=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
389
- mode='lines', name='Ideal Fit', line=dict(dash='dash', color='red')))
390
- st.plotly_chart(fig_reg, use_container_width=True)
391
-
392
- st.subheader("Residual Plot")
393
- actual_predicted_df['Residuals'] = actual_predicted_df['Actual'] - actual_predicted_df['Predicted']
394
- fig_res = px.scatter(actual_predicted_df, x='Predicted', y='Residuals',
395
- title='Residual Plot',
396
- labels={'Predicted': f'Predicted {target_col}', 'Residuals': 'Residuals'})
397
- fig_res.add_hline(y=0, line_dash="dash", line_color="red")
398
- st.plotly_chart(fig_res, use_container_width=True)
399
-
400
- elif prediction_type == 'Classification':
401
- st.write(f"**Accuracy:** {results['accuracy']:.4f}")
402
- st.write(f"**Precision (weighted):** {results['precision']:.4f}")
403
- st.write(f"**Recall (weighted):** {results['recall']:.4f}")
404
- st.write(f"**F1-Score (weighted):** {results['f1_score']:.4f}")
405
-
406
- st.subheader("Confusion Matrix")
407
- conf_matrix = results['confusion_matrix']
408
- class_labels = results.get('class_labels', [str(i) for i in range(len(conf_matrix))])
409
- fig_cm = px.imshow(conf_matrix,
410
- labels=dict(x="Predicted", y="True", color="Count"),
411
- x=class_labels,
412
- y=class_labels,
413
- text_auto=True,
414
- color_continuous_scale="Viridis",
415
- title="Confusion Matrix")
416
- st.plotly_chart(fig_cm, use_container_width=True)
417
-
418
- st.subheader("Classification Report")
419
- # Convert dict to DataFrame for nice display
420
- report_df = pd.DataFrame(results['classification_report']).transpose()
421
- st.dataframe(report_df)
422
-
423
- st.subheader("Feature Importances")
424
- feature_importances_df = pd.DataFrame(list(results['feature_importances'].items()), columns=['Feature', 'Importance'])
425
- fig_fi = px.bar(feature_importances_df, x='Importance', y='Feature', orientation='h',
426
- title='Feature Importances',
427
- labels={'Importance': 'Importance Score', 'Feature': 'Feature Name'})
428
- fig_fi.update_layout(yaxis={'categoryorder':'total ascending'}) # Sort bars
429
- st.plotly_chart(fig_fi, use_container_width=True)
430
- else:
431
- st.error(f"Error during prediction: {response.json().get('detail', 'Unknown error')}")
432
- except requests.exceptions.ConnectionError:
433
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
434
- except Exception as e:
435
- st.error(f"An unexpected error occurred during prediction: {e}")
436
- else:
437
- st.info("Please preprocess data first to use the Prediction Module.")
438
-
439
- # --- Time Series Causal Discovery Module ---
440
- st.header("7. Time Series Causal Discovery ⏰")
441
- if st.session_state.processed_data:
442
- st.write("Infer causal relationships in time-series data using Granger Causality.")
443
- st.info("Ensure your dataset includes a timestamp column and that variables are numeric.")
444
-
445
- all_columns = st.session_state.processed_columns
446
-
447
- # Heuristic to suggest potential timestamp columns (object/string type, or first column)
448
- potential_ts_cols = [col for col in all_columns if pd.DataFrame(st.session_state.processed_data)[col].dtype == 'object']
449
- if not potential_ts_cols and all_columns: # If no object columns, suggest the first column
450
- potential_ts_cols = [all_columns[0]]
451
-
452
- timestamp_col = st.selectbox(
453
- "Select Timestamp Column:",
454
- potential_ts_cols if potential_ts_cols else ["No suitable timestamp column found. Please check data."],
455
- key="ts_col_select"
456
- )
457
-
458
- # Filter out timestamp column and non-numeric columns for analysis
459
- variables_for_ts_analysis = [
460
- col for col in all_columns if col != timestamp_col and pd.api.types.is_numeric_dtype(pd.DataFrame(st.session_state.processed_data)[col])
461
- ]
462
-
463
- variables_to_analyze = st.multiselect(
464
- "Select Variables to Analyze for Granger Causality:",
465
- variables_for_ts_analysis,
466
- default=variables_for_ts_analysis,
467
- key="ts_vars_select"
468
- )
469
-
470
- max_lags = st.number_input(
471
- "Max Lags (for Granger Causality):",
472
- min_value=1,
473
- value=5, # Default value
474
- step=1,
475
- help="The maximum number of lagged observations to consider for causality."
476
- )
477
-
478
- if st.button("Discover Time Series Causality", key="ts_discover_button"):
479
- if not timestamp_col or not variables_to_analyze:
480
- st.warning("Please select a timestamp column and at least one variable to analyze.")
481
- elif "No suitable timestamp column found" in timestamp_col:
482
- st.error("Cannot proceed. Please ensure your data has a suitable timestamp column.")
483
- else:
484
- st.info("Performing Granger Causality tests...")
485
- try:
486
- response = requests.post(
487
- f"{FLASK_API_URL}/timeseries/discover_causality",
488
- json={
489
- "data": st.session_state.processed_data,
490
- "timestamp_col": timestamp_col,
491
- "variables_to_analyze": variables_to_analyze,
492
- "max_lags": max_lags
493
- }
494
- )
495
-
496
- if response.status_code == 200:
497
- results = response.json()['results']
498
- st.success("Time Series Causal Discovery Complete!")
499
- st.subheader("Granger Causality Test Results")
500
-
501
- if results:
502
- # Convert results to a DataFrame for better display
503
- results_df = pd.DataFrame(results)
504
- results_df['p_value'] = results_df['p_value'].round(4) # Round p-values
505
- st.dataframe(results_df)
506
-
507
- st.markdown("**Interpretation:** A small p-value (typically < 0.05) suggests that the 'cause' variable Granger-causes the 'effect' variable. This means past values of the 'cause' variable help predict future values of the 'effect' variable, even when past values of the 'effect' variable are considered.")
508
- st.markdown(f"*(Note: Granger Causality implies predictive causality, not necessarily true mechanistic causality. Also, ensure your time series are stationary for robust results.)*")
509
-
510
- # Optionally, visualize a simple causality graph
511
- st.subheader("Granger Causality Graph")
512
- fig_ts_graph = go.Figure()
513
- nodes = []
514
- edges = []
515
- edge_colors = []
516
-
517
- # Add nodes
518
- for i, var in enumerate(variables_to_analyze):
519
- nodes.append(dict(id=var, label=var, x=np.cos(i*2*np.pi/len(variables_to_analyze)), y=np.sin(i*2*np.pi/len(variables_to_analyze))))
520
-
521
- # Add edges
522
- for res in results:
523
- if res['p_value'] < 0.05: # Consider it a causal link if p-value is below significance
524
- edges.append(dict(source=res['cause'], target=res['effect'], value=1/res['p_value'], title=f"p={res['p_value']:.4f}"))
525
- edge_colors.append("blue")
526
- else:
527
- # Optional: Show non-significant edges in a different color or omit
528
- pass
529
-
530
- # Use a simple network graph layout (Spring layout is common)
531
- # For a truly interactive graph, you might need a different library or more complex Plotly setup
532
- # This is a very basic attempt to visualize; consider more robust solutions like NetworkX + Plotly/Dash
533
-
534
- # Simple way to draw arrows for significant relationships
535
- significant_edges = [edge for edge in results if edge['p_value'] < 0.05]
536
- if significant_edges:
537
- st.write("Visualizing significant (p < 0.05) Granger causal links:")
538
- # This needs a more robust way to draw directed edges in plotly if using just scatter/lines.
539
- # For now, let's just list them clearly.
540
- for edge in significant_edges:
541
- st.write(f"➡️ **{edge['cause']}** Granger-causes **{edge['effect']}** (p={edge['p_value']:.4f})")
542
- else:
543
- st.info("No significant Granger causal links found at p < 0.05.")
544
-
545
- else:
546
- st.info("No Granger Causality relationships found or data insufficient.")
547
-
548
- else:
549
- st.error(f"Error during time-series causal discovery: {response.json().get('detail', 'Unknown error')}")
550
- except requests.exceptions.ConnectionError:
551
- st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
552
- except Exception as e:
553
- st.error(f"An unexpected error occurred during time-series causal discovery: {e}")
554
- else:
555
- st.info("Please preprocess data first to use the Time Series Causal Discovery Module.")
556
-
557
- # --- CausalBox Chat Assistant ---
558
- st.header("8. CausalBox Chat Assistant 🤖")
559
- st.write("Ask questions about your loaded dataset, causal concepts, or the discovered causal graph!")
560
-
561
- # Initialize chat history in session state
562
- if "messages" not in st.session_state:
563
- st.session_state.messages = []
564
-
565
- # Display chat messages from history on app rerun
566
- for message in st.session_state.messages:
567
- with st.chat_message(message["role"]):
568
- st.markdown(message["content"])
569
-
570
- # Accept user input
571
- if prompt := st.chat_input("Ask me anything about CausalBox..."):
572
- # Add user message to chat history
573
- st.session_state.messages.append({"role": "user", "content": prompt})
574
- # Display user message in chat message container
575
- with st.chat_message("user"):
576
- st.markdown(prompt)
577
-
578
- # Prepare session context to send to the backend
579
- session_context = {
580
- "processed_data": st.session_state.processed_data,
581
- "processed_columns": st.session_state.processed_columns,
582
- "causal_graph_adj": st.session_state.causal_graph_adj,
583
- "causal_graph_nodes": st.session_state.causal_graph_nodes,
584
- # Add any other relevant session state variables that the chatbot might need
585
- }
586
-
587
- with st.spinner("Thinking..."):
588
- try:
589
- response = requests.post(
590
- f"{FLASK_API_URL}/chatbot/message",
591
- json={
592
- "user_message": prompt,
593
- "session_context": session_context
594
- }
595
- )
596
-
597
- if response.status_code == 200:
598
- chatbot_response_text = response.json().get('response', 'Sorry, I could not generate a response.')
599
- else:
600
- chatbot_response_text = f"Error from chatbot backend: {response.json().get('detail', 'Unknown error')}"
601
- except requests.exceptions.ConnectionError:
602
- chatbot_response_text = f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running."
603
- except Exception as e:
604
- chatbot_response_text = f"An unexpected error occurred while getting chatbot response: {e}"
605
-
606
- # Display assistant response in chat message container
607
- with st.chat_message("assistant"):
608
- st.markdown(chatbot_response_text)
609
- # Add assistant response to chat history
610
- st.session_state.messages.append({"role": "assistant", "content": chatbot_response_text})
611
-
612
- # --- Future Work (Simplified) ---
613
- st.header("Future Work 🚀")
614
- st.markdown("""
615
- - **🔄 Auto-causal graph refresh:** Monitor dataset updates and automatically refresh the causal graph.
616
- """)
617
-
618
- st.markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  st.info("Developed by CausalBox Team. For support, please contact us.")
 
1
+ # streamlit_app.py
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import requests
5
+ import json
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ import numpy as np # For random array in placeholders
9
+ import os
10
+ import io
11
+ # Configuration
12
+ FLASK_API_URL = "http://localhost:5000" # Ensure this matches your Flask app's host and port
13
+
14
+ st.set_page_config(layout="wide", page_title="CausalBox Toolkit")
15
+
16
+ st.title("🔬 CausalBox: A Causal Inference Toolkit")
17
+ st.markdown("Uncover causal relationships, simulate interventions, and estimate treatment effects.")
18
+
19
+ # --- Session State Initialization ---
20
+ if 'processed_data' not in st.session_state:
21
+ st.session_state.processed_data = None
22
+ if 'processed_columns' not in st.session_state:
23
+ st.session_state.processed_columns = None
24
+ if 'causal_graph_adj' not in st.session_state:
25
+ st.session_state.causal_graph_adj = None
26
+ if 'causal_graph_nodes' not in st.session_state:
27
+ st.session_state.causal_graph_nodes = None
28
+
29
+ # --- Data Preprocessing Module ---
30
+ st.header("1. Data Preprocessor 🧹")
31
+ st.write("Upload your CSV dataset or use a generated sample dataset.")
32
+
33
+ # Option to use generated sample dataset
34
if st.button("Use Sample Dataset (sample_dataset.csv)"):
    # Resolve the CSV relative to this script so the lookup does not depend on
    # the working directory Streamlit was launched from.
    sample_csv_path = os.path.join(os.path.dirname(__file__), 'data', 'sample_dataset.csv')

    if os.path.exists(sample_csv_path):
        with open(sample_csv_path, 'rb') as f:
            csv_content = f.read()

        # Multipart upload: 'file' is the form field the backend reads, followed
        # by the filename, the raw bytes, and the content type.
        files = {'file': ('sample_dataset.csv', csv_content, 'text/csv')}

        try:
            response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
            # Turn 4xx/5xx replies into HTTPError so they are reported by the
            # dedicated handler below instead of being parsed as a success.
            response.raise_for_status()
            processed_data_json = response.json()

            # Persist the preprocessed rows and column names for the later modules.
            st.session_state.processed_data = processed_data_json['data']
            st.session_state.processed_columns = processed_data_json['columns']
            st.success("Sample dataset loaded and preprocessed successfully!")

            # Confirmation for the user: column names plus the first few rows.
            st.json(processed_data_json['columns'])
            st.dataframe(pd.DataFrame(st.session_state.processed_data).head())

        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except requests.exceptions.HTTPError as http_err:
            # Include the server's body so backend validation errors are visible.
            st.error(f"HTTP error occurred: {http_err} - Server response: {http_err.response.text}")
        except Exception as e:
            st.error(f"An unexpected error occurred: {e}")
    else:
        st.error(f"Sample dataset not found at {sample_csv_path}. Please ensure it exists in your 'data' folder.")

    # NOTE: a stale copy of the previous response handling used to follow here
    # (a second status-code check and an `except` with no matching `try`).
    # It was a syntax error and has been removed; its DataFrame preview now
    # lives inside the `try` block above.
83
+
84
+
85
# User-supplied CSV upload. The script is re-executed on widget interaction, so
# without a guard the same file would be re-posted every time and would
# overwrite data loaded by the sample-dataset button in the same run.
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    file_bytes = uploaded_file.getvalue()
    # Identify the upload by name, size and content hash. The built-in hash is
    # only compared against values produced earlier in this same process.
    upload_signature = (uploaded_file.name, len(file_bytes), hash(file_bytes))

    if st.session_state.get('_last_upload_signature') == upload_signature:
        # Already preprocessed successfully: skip the redundant backend call.
        st.caption(
            f"'{uploaded_file.name}' has already been preprocessed. "
            "Remove and re-add the file to preprocess it again."
        )
    else:
        st.info("Uploading and preprocessing data...")
        files = {'file': (uploaded_file.name, file_bytes, 'text/csv')}
        try:
            response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
            if response.status_code == 200:
                result = response.json()
                st.session_state.processed_data = result['data']
                st.session_state.processed_columns = result['columns']
                # Record the signature only on success so failed uploads are retried.
                st.session_state['_last_upload_signature'] = upload_signature
                st.success("File preprocessed successfully!")
                st.dataframe(pd.DataFrame(st.session_state.processed_data).head())  # Display first few rows
            else:
                st.error(f"Error during preprocessing: {response.json().get('detail', 'Unknown error')}")
        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except Exception as e:
            st.error(f"An unexpected error occurred: {e}")
elif '_last_upload_signature' in st.session_state:
    # The uploader was cleared: forget the signature so re-adding the same file
    # triggers preprocessing again.
    del st.session_state['_last_upload_signature']
103
+
104
# --- Causal Discovery Module ---
st.header("2. Causal Discovery 🕵️‍♂️")
if not st.session_state.processed_data:
    st.info("Please preprocess data first to enable causal discovery.")
else:
    st.write("Learn the causal structure from your preprocessed data.")

    # Display label -> algorithm code sent to the backend. The selectbox options
    # are the keys of this mapping, in insertion order.
    algo_map = {
        "PC Algorithm": "pc",
        "GES (Greedy Equivalence Search) - Placeholder": "ges",
        "NOTEARS - Placeholder": "notears"
    }
    discovery_algo = st.selectbox("Select Causal Discovery Algorithm:", tuple(algo_map))

    if st.button("Discover Causal Graph"):
        st.info(f"Discovering graph using {discovery_algo}...")
        selected_algo_code = algo_map[discovery_algo]
        discovery_payload = {
            "data": st.session_state.processed_data,
            "algorithm": selected_algo_code,
        }

        try:
            response = requests.post(f"{FLASK_API_URL}/discover/", json=discovery_payload)
            if response.status_code != 200:
                st.error(f"Error during causal discovery: {response.json().get('detail', 'Unknown error')}")
            else:
                result = response.json()
                # Store the graph together with the column names used as node labels.
                st.session_state.causal_graph_adj = result['graph']
                st.session_state.causal_graph_nodes = st.session_state.processed_columns
                st.success("Causal graph discovered!")
                # The drawing itself is produced by the Causal Graph Visualizer section.
                st.subheader("Causal Graph Visualization")
        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except Exception as e:
            st.error(f"An unexpected error occurred: {e}")
143
+
144
# --- Causal Graph Visualizer Module ---
st.header("3. Causal Graph Visualizer 📊")
if not (st.session_state.causal_graph_adj and st.session_state.causal_graph_nodes):
    st.info("Please discover a causal graph first to visualize it.")
else:
    st.write("Interactive visualization of the discovered causal graph.")
    visualize_payload = {
        "graph": st.session_state.causal_graph_adj,
        "nodes": st.session_state.causal_graph_nodes,
    }
    try:
        response = requests.post(f"{FLASK_API_URL}/visualize/graph", json=visualize_payload)
        if response.status_code != 200:
            st.error(f"Error visualizing graph: {response.json().get('detail', 'Unknown error')}")
        else:
            # The 'graph' field is a JSON string; decode it and hand it to Plotly.
            figure_spec = json.loads(response.json()['graph'])
            fig = go.Figure(figure_spec)
            st.plotly_chart(fig, use_container_width=True)
            st.markdown("""
**Graph Explanation:**
* **Nodes:** Represent variables in your dataset.
* **Arrows (Edges):** Indicate a direct causal influence from one variable (the tail) to another (the head).
* **No Arrow:** Suggests no direct causal relationship was found, or the relationship is mediated by other variables.

This graph helps answer "Why did it happen?" by showing the structural relationships.
""")
    except requests.exceptions.ConnectionError:
        st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
    except Exception as e:
        st.error(f"An unexpected error occurred during visualization: {e}")
173
+
174
+
175
# --- Do-Calculus Engine Module ---
# Lets the user force one variable to a value and compares the outcome
# distribution returned by the backend against the original data.
st.header("4. Do-Calculus Engine 🧪")
if st.session_state.processed_data and st.session_state.causal_graph_adj:
    st.write("Simulate interventions and observe their effects based on the causal graph.")

    intervention_var = st.selectbox(
        "Select variable to intervene on:",
        st.session_state.processed_columns,
        key="inter_var_select"
    )
    # Pick the input widget from the type of the first row's value for the
    # chosen variable: numeric -> number input, anything else -> free text.
    if intervention_var and isinstance(st.session_state.processed_data[0][intervention_var], (int, float)):
        intervention_value = st.number_input(f"Set '{intervention_var}' to value:", key="inter_val_input")
    else:
        intervention_value = st.text_input(f"Set '{intervention_var}' to value:", key="inter_val_input_text")
        st.warning("Categorical intervention values might require specific encoding logic on the backend.")

    if st.button("Perform Intervention"):
        st.info(f"Performing intervention: do('{intervention_var}' = {intervention_value})...")
        try:
            response = requests.post(
                f"{FLASK_API_URL}/intervene/",
                json={
                    "data": st.session_state.processed_data,
                    "intervention_var": intervention_var,
                    "intervention_value": intervention_value,
                    "graph": st.session_state.causal_graph_adj  # Pass graph for advanced do-calculus
                }
            )
            if response.status_code == 200:
                intervened_data = pd.DataFrame(response.json()['intervened_data'])
                st.success("Intervention simulated successfully!")
                st.subheader("Intervened Data (First 10 rows)")
                st.dataframe(intervened_data.head(10))

                # Before/after comparison of the outcome variable. The comparison is
                # only drawn when the dataset has a 'FinalExamScore' column.
                if st.session_state.processed_columns and 'FinalExamScore' in st.session_state.processed_columns:
                    original_df = pd.DataFrame(st.session_state.processed_data)
                    fig_dist = go.Figure()
                    fig_dist.add_trace(go.Histogram(x=original_df['FinalExamScore'], name='Original', opacity=0.7))
                    # Both traces are semi-transparent; an opacity of 0.0 would hide
                    # the intervened distribution completely.
                    fig_dist.add_trace(go.Histogram(x=intervened_data['FinalExamScore'], name='Intervened', opacity=0.7))
                    # Draw the histograms on top of each other so they can be compared directly.
                    fig_dist.update_layout(barmode='overlay')

                    st.plotly_chart(fig_dist, use_container_width=True)
                    st.markdown("""
**Intervention Explanation:**
* By simulating `do(X=x)`, we are forcing the value of X, effectively breaking its causal links from its parents.
* The graph above shows the distribution of a key outcome variable (e.g., `FinalExamScore`) before and after the intervention.
* This helps answer "What if we do this instead?" by showing the predicted outcome.
""")
                else:
                    st.info("Consider adding a relevant outcome variable to your dataset for better intervention analysis.")
            else:
                st.error(f"Error during intervention: {response.json().get('detail', 'Unknown error')}")
        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except Exception as e:
            st.error(f"An unexpected error occurred during intervention: {e}")
else:
    st.info("Please preprocess data and discover a causal graph first to perform interventions.")
235
+
236
# --- Treatment Effect Estimator Module ---
# Collects treatment, outcome, covariates and an estimation method, then asks
# the backend for an ATE estimate and displays it.
st.header("5. Treatment Effect Estimator 🎯")
if st.session_state.processed_data:
    st.write("Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).")

    col1, col2 = st.columns(2)
    with col1:
        treatment_col = st.selectbox(
            "Select Treatment Variable:",
            st.session_state.processed_columns,
            key="treat_col_select"
        )
    with col2:
        outcome_col = st.selectbox(
            "Select Outcome Variable:",
            st.session_state.processed_columns,
            key="outcome_col_select"
        )

    # Covariate candidates are every column except the chosen treatment and outcome.
    all_cols_except_treat_outcome = [
        col for col in st.session_state.processed_columns
        if col not in [treatment_col, outcome_col]
    ]
    covariates = st.multiselect(
        "Select Covariates (confounders):",
        all_cols_except_treat_outcome,
        default=all_cols_except_treat_outcome,  # Default to all other columns
        key="covariates_select"
    )

    # Display label -> method code sent to the backend.
    method_map = {
        "Linear Regression ATE": "linear_regression",
        "Propensity Score Matching - Placeholder": "propensity_score_matching",
        "Inverse Propensity Weighting - Placeholder": "inverse_propensity_weighting",
        "T-learner - Placeholder": "t_learner",
        "S-learner - Placeholder": "s_learner"
    }
    estimation_method = st.selectbox("Select Estimation Method:", tuple(method_map))

    if st.button("Estimate Treatment Effect"):
        if treatment_col == outcome_col:
            # Both selectboxes start on the first column, so this is the default
            # state; the effect of a variable on itself is not meaningful.
            st.warning("Treatment and outcome must be different variables. Please change one of them.")
        else:
            st.info(f"Estimating treatment effect using {estimation_method}...")
            selected_method_code = method_map[estimation_method]

            try:
                response = requests.post(
                    f"{FLASK_API_URL}/treatment/estimate_ate",
                    json={
                        "data": st.session_state.processed_data,
                        "treatment_col": treatment_col,
                        "outcome_col": outcome_col,
                        "covariates": covariates,
                        "method": selected_method_code
                    }
                )
                if response.status_code == 200:
                    ate_result = response.json()['result']
                    st.success(f"Treatment effect estimated using {estimation_method}:")
                    st.write(f"**Estimated ATE: {ate_result:.4f}**")
                    st.markdown("""
**Treatment Effect Explanation:**
* **Average Treatment Effect (ATE):** Measures the average causal effect of a treatment (e.g., `StudyHours`) on an outcome (e.g., `FinalExamScore`) across the entire population.
* It answers "How much does doing X cause a change in Y?".
* This estimation attempts to control for confounders (variables that influence both treatment and outcome) to isolate the true causal effect.
""")
                else:
                    st.error(f"Error during ATE estimation: {response.json().get('detail', 'Unknown error')}")
            except requests.exceptions.ConnectionError:
                st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
            except Exception as e:
                st.error(f"An unexpected error occurred during ATE estimation: {e}")
else:
    st.info("Please preprocess data first to estimate treatment effects.")
314
+
315
# --- Prediction Module ---
st.header("6. Prediction Module 📈")
if not st.session_state.processed_data:
    st.info("Please preprocess data first to use the Prediction Module.")
else:
    st.write("Train a machine learning model for prediction (Regression or Classification).")

    prediction_type = st.selectbox(
        "Select Prediction Type:",
        ("Regression", "Classification"),
        key="prediction_type_select"
    )

    all_columns = st.session_state.processed_columns
    temp_df = pd.DataFrame(st.session_state.processed_data)

    # Candidate targets depend on the task:
    #  * Classification: object/bool columns, or integer columns with fewer
    #    than 20 distinct values (heuristic for discrete labels).
    #  * Regression: any numeric column.
    if prediction_type == 'Classification':
        suitable_target_columns = [
            name for name in all_columns
            if temp_df[name].dtype == 'object'
            or temp_df[name].dtype == 'bool'
            or (pd.api.types.is_integer_dtype(temp_df[name]) and temp_df[name].nunique() < 20)
        ]
    elif prediction_type == 'Regression':
        suitable_target_columns = [
            name for name in all_columns
            if pd.api.types.is_numeric_dtype(temp_df[name])
        ]
    else:
        suitable_target_columns = []

    if suitable_target_columns:
        # Keep the previously chosen target when it is still valid for the
        # current task; otherwise fall back to the first candidate.
        default_target_index = (
            suitable_target_columns.index(st.session_state.target_col_select)
            if 'target_col_select' in st.session_state
            and st.session_state.target_col_select in suitable_target_columns
            else 0
        )

        target_col = st.selectbox(
            "Select Target Variable:",
            suitable_target_columns,
            index=default_target_index,
            key="target_col_select"
        )
    else:
        st.warning(f"No suitable target columns found for {prediction_type}. Please check your data types.")
        target_col = None  # No selectable target; the train button will ask for one

    # Every column other than the target may serve as a feature.
    feature_options = [name for name in all_columns if name != target_col]
    feature_cols = st.multiselect(
        "Select Feature Variables:",
        feature_options,
        default=feature_options,  # Default to all other columns
        key="feature_cols_select"
    )

    if st.button("Train Model & Predict", key="train_predict_button"):
        if target_col and feature_cols:
            st.info(f"Training {prediction_type} model using Random Forest...")
            training_payload = {
                "data": st.session_state.processed_data,
                "target_col": target_col,
                "feature_cols": feature_cols,
                "prediction_type": prediction_type.lower()
            }
            try:
                response = requests.post(
                    f"{FLASK_API_URL}/prediction/train_predict",
                    json=training_payload
                )

                if response.status_code != 200:
                    st.error(f"Error during prediction: {response.json().get('detail', 'Unknown error')}")
                else:
                    results = response.json()['results']
                    st.success(f"{prediction_type} Model Trained Successfully!")
                    st.subheader("Model Performance")

                    if prediction_type == 'Regression':
                        for metric_label, metric_key in (
                            ("R-squared", 'r2_score'),
                            ("Mean Squared Error (MSE)", 'mean_squared_error'),
                            ("Root Mean Squared Error (RMSE)", 'root_mean_squared_error'),
                        ):
                            st.write(f"**{metric_label}:** {results[metric_key]:.4f}")

                        st.subheader("Actual vs. Predicted Plot")
                        actual_predicted_df = pd.DataFrame(results['actual_vs_predicted'])
                        fig_reg = px.scatter(
                            actual_predicted_df, x='Actual', y='Predicted',
                            title='Actual vs. Predicted Values',
                            labels={'Actual': f'Actual {target_col}', 'Predicted': f'Predicted {target_col}'}
                        )
                        # Diagonal reference line spanning the observed range of actual values.
                        actual_low = actual_predicted_df['Actual'].min()
                        actual_high = actual_predicted_df['Actual'].max()
                        fig_reg.add_trace(go.Scatter(
                            x=[actual_low, actual_high],
                            y=[actual_low, actual_high],
                            mode='lines', name='Ideal Fit', line=dict(dash='dash', color='red')
                        ))
                        st.plotly_chart(fig_reg, use_container_width=True)

                        st.subheader("Residual Plot")
                        actual_predicted_df['Residuals'] = (
                            actual_predicted_df['Actual'] - actual_predicted_df['Predicted']
                        )
                        fig_res = px.scatter(
                            actual_predicted_df, x='Predicted', y='Residuals',
                            title='Residual Plot',
                            labels={'Predicted': f'Predicted {target_col}', 'Residuals': 'Residuals'}
                        )
                        fig_res.add_hline(y=0, line_dash="dash", line_color="red")
                        st.plotly_chart(fig_res, use_container_width=True)

                    elif prediction_type == 'Classification':
                        for metric_label, metric_key in (
                            ("Accuracy", 'accuracy'),
                            ("Precision (weighted)", 'precision'),
                            ("Recall (weighted)", 'recall'),
                            ("F1-Score (weighted)", 'f1_score'),
                        ):
                            st.write(f"**{metric_label}:** {results[metric_key]:.4f}")

                        st.subheader("Confusion Matrix")
                        conf_matrix = results['confusion_matrix']
                        # Fall back to positional labels when the backend sends none.
                        fallback_labels = list(map(str, range(len(conf_matrix))))
                        class_labels = results.get('class_labels', fallback_labels)
                        fig_cm = px.imshow(
                            conf_matrix,
                            labels=dict(x="Predicted", y="True", color="Count"),
                            x=class_labels,
                            y=class_labels,
                            text_auto=True,
                            color_continuous_scale="Viridis",
                            title="Confusion Matrix"
                        )
                        st.plotly_chart(fig_cm, use_container_width=True)

                        st.subheader("Classification Report")
                        # Transposed so each key of the report dict becomes a row.
                        report_df = pd.DataFrame(results['classification_report']).transpose()
                        st.dataframe(report_df)

                    st.subheader("Feature Importances")
                    importance_pairs = list(results['feature_importances'].items())
                    feature_importances_df = pd.DataFrame(importance_pairs, columns=['Feature', 'Importance'])
                    fig_fi = px.bar(
                        feature_importances_df, x='Importance', y='Feature', orientation='h',
                        title='Feature Importances',
                        labels={'Importance': 'Importance Score', 'Feature': 'Feature Name'}
                    )
                    fig_fi.update_layout(yaxis={'categoryorder': 'total ascending'})  # Sort bars
                    st.plotly_chart(fig_fi, use_container_width=True)
            except requests.exceptions.ConnectionError:
                st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
            except Exception as e:
                st.error(f"An unexpected error occurred during prediction: {e}")
        else:
            st.warning("Please select a target variable and at least one feature variable.")
453
+
454
# --- Time Series Causal Discovery Module ---
# Collects a timestamp column, numeric variables and a lag limit, sends them to
# the backend's Granger-causality endpoint, and lists the significant links.
st.header("7. Time Series Causal Discovery ⏰")
if st.session_state.processed_data:
    st.write("Infer causal relationships in time-series data using Granger Causality.")
    st.info("Ensure your dataset includes a timestamp column and that variables are numeric.")

    all_columns = st.session_state.processed_columns
    # Build the frame once; it is only used to inspect column dtypes below.
    ts_frame = pd.DataFrame(st.session_state.processed_data)

    # Heuristic to suggest potential timestamp columns (object/string type, or first column)
    potential_ts_cols = [col for col in all_columns if ts_frame[col].dtype == 'object']
    if not potential_ts_cols and all_columns:  # If no object columns, suggest the first column
        potential_ts_cols = [all_columns[0]]

    timestamp_col = st.selectbox(
        "Select Timestamp Column:",
        potential_ts_cols if potential_ts_cols else ["No suitable timestamp column found. Please check data."],
        key="ts_col_select"
    )

    # Filter out timestamp column and non-numeric columns for analysis
    variables_for_ts_analysis = [
        col for col in all_columns
        if col != timestamp_col and pd.api.types.is_numeric_dtype(ts_frame[col])
    ]

    variables_to_analyze = st.multiselect(
        "Select Variables to Analyze for Granger Causality:",
        variables_for_ts_analysis,
        default=variables_for_ts_analysis,
        key="ts_vars_select"
    )

    max_lags = st.number_input(
        "Max Lags (for Granger Causality):",
        min_value=1,
        value=5,  # Default value
        step=1,
        help="The maximum number of lagged observations to consider for causality."
    )

    if st.button("Discover Time Series Causality", key="ts_discover_button"):
        if not timestamp_col or not variables_to_analyze:
            st.warning("Please select a timestamp column and at least one variable to analyze.")
        elif "No suitable timestamp column found" in timestamp_col:
            # The selectbox is showing the placeholder option, not a real column.
            st.error("Cannot proceed. Please ensure your data has a suitable timestamp column.")
        else:
            st.info("Performing Granger Causality tests...")
            try:
                response = requests.post(
                    f"{FLASK_API_URL}/timeseries/discover_causality",
                    json={
                        "data": st.session_state.processed_data,
                        "timestamp_col": timestamp_col,
                        "variables_to_analyze": variables_to_analyze,
                        "max_lags": max_lags
                    }
                )

                if response.status_code == 200:
                    results = response.json()['results']
                    st.success("Time Series Causal Discovery Complete!")
                    st.subheader("Granger Causality Test Results")

                    if results:
                        # Tabulate the results; p-values are rounded for display only.
                        results_df = pd.DataFrame(results)
                        results_df['p_value'] = results_df['p_value'].round(4)
                        st.dataframe(results_df)

                        st.markdown("**Interpretation:** A small p-value (typically < 0.05) suggests that the 'cause' variable Granger-causes the 'effect' variable. This means past values of the 'cause' variable help predict future values of the 'effect' variable, even when past values of the 'effect' variable are considered.")
                        st.markdown("*(Note: Granger Causality implies predictive causality, not necessarily true mechanistic causality. Also, ensure your time series are stationary for robust results.)*")

                        st.subheader("Granger Causality Graph")
                        # Significant links are listed as text rather than drawn.
                        # No edge weight is derived from 1/p, because a p-value of
                        # exactly 0 would raise ZeroDivisionError and hide the results.
                        significant_edges = [edge for edge in results if edge['p_value'] < 0.05]
                        if significant_edges:
                            st.write("Visualizing significant (p < 0.05) Granger causal links:")
                            for edge in significant_edges:
                                st.write(f"➡️ **{edge['cause']}** Granger-causes **{edge['effect']}** (p={edge['p_value']:.4f})")
                        else:
                            st.info("No significant Granger causal links found at p < 0.05.")

                    else:
                        st.info("No Granger Causality relationships found or data insufficient.")

                else:
                    st.error(f"Error during time-series causal discovery: {response.json().get('detail', 'Unknown error')}")
            except requests.exceptions.ConnectionError:
                st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
            except Exception as e:
                st.error(f"An unexpected error occurred during time-series causal discovery: {e}")
else:
    st.info("Please preprocess data first to use the Time Series Causal Discovery Module.")
571
+
572
# --- CausalBox Chat Assistant ---
st.header("8. CausalBox Chat Assistant 🤖")
st.write("Ask questions about your loaded dataset, causal concepts, or the discovered causal graph!")

# The conversation is kept in session state.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation.
for past_message in st.session_state.messages:
    with st.chat_message(past_message["role"]):
        st.markdown(past_message["content"])

# Accept user input; the value is falsy when nothing was submitted.
prompt = st.chat_input("Ask me anything about CausalBox...")
if prompt:
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Snapshot of the shared state sent to the backend along with the question.
    session_context = {
        state_key: st.session_state[state_key]
        for state_key in (
            "processed_data",
            "processed_columns",
            "causal_graph_adj",
            "causal_graph_nodes",
        )
    }
    chat_payload = {
        "user_message": prompt,
        "session_context": session_context
    }

    with st.spinner("Thinking..."):
        try:
            response = requests.post(f"{FLASK_API_URL}/chatbot/message", json=chat_payload)

            if response.status_code != 200:
                chatbot_response_text = f"Error from chatbot backend: {response.json().get('detail', 'Unknown error')}"
            else:
                chatbot_response_text = response.json().get('response', 'Sorry, I could not generate a response.')
        except requests.exceptions.ConnectionError:
            chatbot_response_text = f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running."
        except Exception as e:
            chatbot_response_text = f"An unexpected error occurred while getting chatbot response: {e}"

    # Show the reply (or the error text) and append it to the history.
    with st.chat_message("assistant"):
        st.markdown(chatbot_response_text)
    st.session_state.messages.append({"role": "assistant", "content": chatbot_response_text})
626
+
627
# --- Future Work (Simplified) ---
# Closing section: a roadmap bullet, a horizontal rule, and the footer note.
_future_work_markdown = """
- **🔄 Auto-causal graph refresh:** Monitor dataset updates and automatically refresh the causal graph.
"""

st.header("Future Work 🚀")
st.markdown(_future_work_markdown)

st.markdown("---")
st.info("Developed by CausalBox Team. For support, please contact us.")