Update streamlit_app.py

streamlit_app.py (CHANGED: +633 -618)

@@ -1,619 +1,634 @@
[Previous version of streamlit_app.py (619 lines) removed; superseded by the 634-line version below.]
# streamlit_app.py
import streamlit as st
import pandas as pd
import requests
import json
import plotly.express as px
import plotly.graph_objects as go
import numpy as np  # For random array in placeholders
import os
import io

# Configuration
FLASK_API_URL = "http://localhost:5000"  # Ensure this matches your Flask app's host and port

st.set_page_config(layout="wide", page_title="CausalBox Toolkit")

st.title("🔬 CausalBox: A Causal Inference Toolkit")
st.markdown("Uncover causal relationships, simulate interventions, and estimate treatment effects.")

# --- Session State Initialization ---
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
if 'processed_columns' not in st.session_state:
    st.session_state.processed_columns = None
if 'causal_graph_adj' not in st.session_state:
    st.session_state.causal_graph_adj = None
if 'causal_graph_nodes' not in st.session_state:
    st.session_state.causal_graph_nodes = None

# --- Data Preprocessing Module ---
st.header("1. Data Preprocessor 🧹")
st.write("Upload your CSV dataset or use a generated sample dataset.")

# Option to use generated sample dataset
if st.button("Use Sample Dataset (sample_dataset.csv)"):
    # Path to the sample_dataset.csv relative to streamlit_app.py
    # Assumes sample_dataset.csv is in the 'data' folder at the root of the project
    sample_csv_path = os.path.join(os.path.dirname(__file__), 'data', 'sample_dataset.csv')

    if os.path.exists(sample_csv_path):
        with open(sample_csv_path, 'rb') as f:
            csv_content = f.read()

        # Prepare the file for upload using 'files' parameter for multipart/form-data
        # 'file' is the name of the input field Flask expects (request.files['file'])
        # 'sample_dataset.csv' is the filename
        # csv_content is the actual binary content of the file
        # 'text/csv' is the content type
        files = {'file': ('sample_dataset.csv', csv_content, 'text/csv')}

        try:
            # Send the file to Flask backend
            response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
            response.raise_for_status()  # Raise an HTTPError for bad responses (4xx or 5xx)
            processed_data_json = response.json()

            # Update Streamlit session state with processed data and columns
            st.session_state.processed_data = processed_data_json['data']
            st.session_state.processed_columns = processed_data_json['columns']
            st.success("Sample dataset loaded and preprocessed successfully!")

            # Optional: Display the columns and a snippet of data for confirmation
            st.json(processed_data_json['columns'])
            st.dataframe(pd.DataFrame(st.session_state.processed_data).head())  # Display first few rows

        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except requests.exceptions.HTTPError as http_err:  # Catch HTTPError specifically for detailed error
            st.error(f"HTTP error occurred: {http_err} - Server response: {http_err.response.text}")
        except Exception as e:
            st.error(f"Could not load or process sample dataset: {e}")
    else:
        st.error(f"Sample dataset not found at {sample_csv_path}. Please ensure it exists in your 'data' folder.")


uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    st.info("Uploading and preprocessing data...")
    files = {'file': (uploaded_file.name, uploaded_file.getvalue(), 'text/csv')}
    try:
        response = requests.post(f"{FLASK_API_URL}/preprocess/upload", files=files)
        if response.status_code == 200:
            result = response.json()
            st.session_state.processed_data = result['data']
            st.session_state.processed_columns = result['columns']
            st.success("File preprocessed successfully!")
            st.dataframe(pd.DataFrame(st.session_state.processed_data).head())  # Display first few rows
        else:
            st.error(f"Error during preprocessing: {response.json().get('detail', 'Unknown error')}")
    except requests.exceptions.ConnectionError:
        st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
    except Exception as e:
        st.error(f"An unexpected error occurred: {e}")

# --- Causal Discovery Module ---
st.header("2. Causal Discovery 🕵️♂️")
if st.session_state.processed_data:
    st.write("Learn the causal structure from your preprocessed data.")

    discovery_algo = st.selectbox(
        "Select Causal Discovery Algorithm:",
        ("PC Algorithm", "GES (Greedy Equivalence Search) - Placeholder", "NOTEARS - Placeholder")
    )

    if st.button("Discover Causal Graph"):
        st.info(f"Discovering graph using {discovery_algo}...")
        algo_map = {
            "PC Algorithm": "pc",
            "GES (Greedy Equivalence Search) - Placeholder": "ges",
            "NOTEARS - Placeholder": "notears"
        }
        selected_algo_code = algo_map[discovery_algo]

        try:
            response = requests.post(
                f"{FLASK_API_URL}/discover/",
                json={"data": st.session_state.processed_data, "algorithm": selected_algo_code}
            )
            if response.status_code == 200:
                result = response.json()
                st.session_state.causal_graph_adj = result['graph']
                st.session_state.causal_graph_nodes = st.session_state.processed_columns
                st.success("Causal graph discovered!")
                st.subheader("Causal Graph Visualization")
                # Visualization will be handled by the Causal Graph Visualizer section
            else:
                st.error(f"Error during causal discovery: {response.json().get('detail', 'Unknown error')}")
        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except Exception as e:
            st.error(f"An unexpected error occurred: {e}")
else:
    st.info("Please preprocess data first to enable causal discovery.")

# --- Causal Graph Visualizer Module ---
st.header("3. Causal Graph Visualizer 📊")
if st.session_state.causal_graph_adj and st.session_state.causal_graph_nodes:
    st.write("Interactive visualization of the discovered causal graph.")
    try:
        response = requests.post(
            f"{FLASK_API_URL}/visualize/graph",
            json={"graph": st.session_state.causal_graph_adj, "nodes": st.session_state.causal_graph_nodes}
        )
        if response.status_code == 200:
            graph_json = response.json()['graph']
            fig = go.Figure(json.loads(graph_json))
            st.plotly_chart(fig, use_container_width=True)
            st.markdown("""
**Graph Explanation:**
* **Nodes:** Represent variables in your dataset.
* **Arrows (Edges):** Indicate a direct causal influence from one variable (the tail) to another (the head).
* **No Arrow:** Suggests no direct causal relationship was found, or the relationship is mediated by other variables.

This graph helps answer "Why did it happen?" by showing the structural relationships.
""")
        else:
            st.error(f"Error visualizing graph: {response.json().get('detail', 'Unknown error')}")
    except requests.exceptions.ConnectionError:
        st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
    except Exception as e:
        st.error(f"An unexpected error occurred during visualization: {e}")
else:
    st.info("Please discover a causal graph first to visualize it.")


# --- Do-Calculus Engine Module ---
st.header("4. Do-Calculus Engine 🧪")
if st.session_state.processed_data and st.session_state.causal_graph_adj:
    st.write("Simulate interventions and observe their effects based on the causal graph.")

    intervention_var = st.selectbox(
        "Select variable to intervene on:",
        st.session_state.processed_columns,
        key="inter_var_select"
    )
    # Attempt to infer type for intervention_value input
    # Simplified approach: assuming numerical for now due to preprocessor output
    if intervention_var and isinstance(st.session_state.processed_data[0][intervention_var], (int, float)):
        intervention_value = st.number_input(f"Set '{intervention_var}' to value:", key="inter_val_input")
    else:  # Treat as string/categorical for input, then try to preprocess for API
        intervention_value = st.text_input(f"Set '{intervention_var}' to value:", key="inter_val_input_text")
        st.warning("Categorical intervention values might require specific encoding logic on the backend.")

    if st.button("Perform Intervention"):
        st.info(f"Performing intervention: do('{intervention_var}' = {intervention_value})...")
        try:
            response = requests.post(
                f"{FLASK_API_URL}/intervene/",
                json={
                    "data": st.session_state.processed_data,
                    "intervention_var": intervention_var,
                    "intervention_value": intervention_value,
                    "graph": st.session_state.causal_graph_adj  # Pass graph for advanced do-calculus
                }
            )
            if response.status_code == 200:
                intervened_data = pd.DataFrame(response.json()['intervened_data'])
                st.success("Intervention simulated successfully!")
                st.subheader("Intervened Data (First 10 rows)")
                st.dataframe(intervened_data.head(10))

                # Simple comparison visualization (e.g., histogram of outcome variable)
                if st.session_state.processed_columns and 'FinalExamScore' in st.session_state.processed_columns:
                    original_df = pd.DataFrame(st.session_state.processed_data)
                    fig_dist = go.Figure()
                    fig_dist.add_trace(go.Histogram(x=original_df['FinalExamScore'], name='Original', opacity=0.7))
                    fig_dist.add_trace(go.Histogram(x=intervened_data['FinalExamScore'], name='Intervened', opacity=0.7))  # Semi-transparent so both distributions stay visible

                    st.plotly_chart(fig_dist, use_container_width=True)
                    st.markdown("""
**Intervention Explanation:**
* By simulating `do(X=x)`, we are forcing the value of X, effectively breaking its causal links from its parents.
* The graph above shows the distribution of a key outcome variable (e.g., `FinalExamScore`) before and after the intervention.
* This helps answer "What if we do this instead?" by showing the predicted outcome.
""")
                else:
                    st.info("Consider adding a relevant outcome variable to your dataset for better intervention analysis.")
            else:
                st.error(f"Error during intervention: {response.json().get('detail', 'Unknown error')}")
        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except Exception as e:
            st.error(f"An unexpected error occurred during intervention: {e}")
else:
    st.info("Please preprocess data and discover a causal graph first to perform interventions.")

# --- Treatment Effect Estimator Module ---
st.header("5. Treatment Effect Estimator 🎯")
if st.session_state.processed_data:
    st.write("Estimate Average Treatment Effect (ATE) or Conditional Treatment Effect (CATE).")

    col1, col2 = st.columns(2)
    with col1:
        treatment_col = st.selectbox(
            "Select Treatment Variable:",
            st.session_state.processed_columns,
            key="treat_col_select"
        )
    with col2:
        outcome_col = st.selectbox(
            "Select Outcome Variable:",
            st.session_state.processed_columns,
            key="outcome_col_select"
        )

    all_cols_except_treat_outcome = [col for col in st.session_state.processed_columns if col not in [treatment_col, outcome_col]]
    covariates = st.multiselect(
        "Select Covariates (confounders):",
        all_cols_except_treat_outcome,
        default=all_cols_except_treat_outcome,  # Default to all other columns
        key="covariates_select"
    )

    estimation_method = st.selectbox(
        "Select Estimation Method:",
        (
            "Linear Regression ATE",
            "Propensity Score Matching - Placeholder",
            "Inverse Propensity Weighting - Placeholder",
            "T-learner - Placeholder",
            "S-learner - Placeholder"
        )
    )

    if st.button("Estimate Treatment Effect"):
        st.info(f"Estimating treatment effect using {estimation_method}...")
        method_map = {
            "Linear Regression ATE": "linear_regression",
            "Propensity Score Matching - Placeholder": "propensity_score_matching",
            "Inverse Propensity Weighting - Placeholder": "inverse_propensity_weighting",
            "T-learner - Placeholder": "t_learner",
            "S-learner - Placeholder": "s_learner"
        }
        selected_method_code = method_map[estimation_method]

        try:
            response = requests.post(
                f"{FLASK_API_URL}/treatment/estimate_ate",
                json={
                    "data": st.session_state.processed_data,
                    "treatment_col": treatment_col,
                    "outcome_col": outcome_col,
                    "covariates": covariates,
                    "method": selected_method_code
                }
            )
            if response.status_code == 200:
                ate_result = response.json()['result']
                st.success(f"Treatment effect estimated using {estimation_method}:")
                st.write(f"**Estimated ATE: {ate_result:.4f}**")
                st.markdown("""
**Treatment Effect Explanation:**
* **Average Treatment Effect (ATE):** Measures the average causal effect of a treatment (e.g., `StudyHours`) on an outcome (e.g., `FinalExamScore`) across the entire population.
* It answers "How much does doing X cause a change in Y?".
* This estimation attempts to control for confounders (variables that influence both treatment and outcome) to isolate the true causal effect.
""")
            else:
                st.error(f"Error during ATE estimation: {response.json().get('detail', 'Unknown error')}")
        except requests.exceptions.ConnectionError:
            st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
        except Exception as e:
            st.error(f"An unexpected error occurred during ATE estimation: {e}")
else:
    st.info("Please preprocess data first to estimate treatment effects.")

# --- Prediction Module ---
st.header("6. Prediction Module 📈")
if st.session_state.processed_data:
    st.write("Train a machine learning model for prediction (Regression or Classification).")

    prediction_type = st.selectbox(
        "Select Prediction Type:",
        ("Regression", "Classification"),
        key="prediction_type_select"
    )

    all_columns = st.session_state.processed_columns

    suitable_target_columns = []
    if st.session_state.processed_data:
        temp_df = pd.DataFrame(st.session_state.processed_data)
        for col in all_columns:
            # For classification, check if column is object type (string), boolean,
            # or has a limited number of unique integer values (e.g., less than 20 unique values)
            if prediction_type == 'Classification':
                if temp_df[col].dtype == 'object' or temp_df[col].dtype == 'bool':
                    suitable_target_columns.append(col)
                elif pd.api.types.is_integer_dtype(temp_df[col]) and temp_df[col].nunique() < 20:  # Heuristic for discrete integers
                    suitable_target_columns.append(col)
            # For regression, primarily numerical columns
            elif prediction_type == 'Regression':
                if pd.api.types.is_numeric_dtype(temp_df[col]):
                    suitable_target_columns.append(col)

    if not suitable_target_columns:
        st.warning(f"No suitable target columns found for {prediction_type}. Please check your data types.")
        target_col = None  # Set to None to prevent error if no columns are found
    else:
        # Try to pre-select the currently chosen target_col if it's still suitable
        # Otherwise, default to the first suitable column
        if 'target_col_select' in st.session_state and st.session_state.target_col_select in suitable_target_columns:
            default_target_index = suitable_target_columns.index(st.session_state.target_col_select)
        else:
            default_target_index = 0

        target_col = st.selectbox(
            "Select Target Variable:",
            suitable_target_columns,
            index=default_target_index,
            key="target_col_select"
        )

    # Filter out the target column from feature options
    feature_options = [col for col in all_columns if col != target_col]
    feature_cols = st.multiselect(
        "Select Feature Variables:",
        feature_options,
        default=feature_options,  # Default to all other columns
        key="feature_cols_select"
    )

    if st.button("Train Model & Predict", key="train_predict_button"):
        if not target_col or not feature_cols:
            st.warning("Please select a target variable and at least one feature variable.")
        else:
            st.info(f"Training {prediction_type} model using Random Forest...")
            try:
                response = requests.post(
                    f"{FLASK_API_URL}/prediction/train_predict",
                    json={
                        "data": st.session_state.processed_data,
                        "target_col": target_col,
                        "feature_cols": feature_cols,
                        "prediction_type": prediction_type.lower()
                    }
                )

                if response.status_code == 200:
                    results = response.json()['results']
                    st.success(f"{prediction_type} Model Trained Successfully!")
                    st.subheader("Model Performance")

                    if prediction_type == 'Regression':
                        st.write(f"**R-squared:** {results['r2_score']:.4f}")
                        st.write(f"**Mean Squared Error (MSE):** {results['mean_squared_error']:.4f}")
                        st.write(f"**Root Mean Squared Error (RMSE):** {results['root_mean_squared_error']:.4f}")

                        st.subheader("Actual vs. Predicted Plot")
                        actual_predicted_df = pd.DataFrame(results['actual_vs_predicted'])
                        fig_reg = px.scatter(actual_predicted_df, x='Actual', y='Predicted',
                                             title='Actual vs. Predicted Values',
                                             labels={'Actual': f'Actual {target_col}', 'Predicted': f'Predicted {target_col}'})
                        fig_reg.add_trace(go.Scatter(x=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
                                                     y=[actual_predicted_df['Actual'].min(), actual_predicted_df['Actual'].max()],
                                                     mode='lines', name='Ideal Fit', line=dict(dash='dash', color='red')))
                        st.plotly_chart(fig_reg, use_container_width=True)

                        st.subheader("Residual Plot")
                        actual_predicted_df['Residuals'] = actual_predicted_df['Actual'] - actual_predicted_df['Predicted']
                        fig_res = px.scatter(actual_predicted_df, x='Predicted', y='Residuals',
                                             title='Residual Plot',
                                             labels={'Predicted': f'Predicted {target_col}', 'Residuals': 'Residuals'})
                        fig_res.add_hline(y=0, line_dash="dash", line_color="red")
                        st.plotly_chart(fig_res, use_container_width=True)

                    elif prediction_type == 'Classification':
                        st.write(f"**Accuracy:** {results['accuracy']:.4f}")
                        st.write(f"**Precision (weighted):** {results['precision']:.4f}")
                        st.write(f"**Recall (weighted):** {results['recall']:.4f}")
                        st.write(f"**F1-Score (weighted):** {results['f1_score']:.4f}")

                        st.subheader("Confusion Matrix")
                        conf_matrix = results['confusion_matrix']
                        class_labels = results.get('class_labels', [str(i) for i in range(len(conf_matrix))])
                        fig_cm = px.imshow(conf_matrix,
                                           labels=dict(x="Predicted", y="True", color="Count"),
                                           x=class_labels,
                                           y=class_labels,
                                           text_auto=True,
                                           color_continuous_scale="Viridis",
                                           title="Confusion Matrix")
                        st.plotly_chart(fig_cm, use_container_width=True)

                        st.subheader("Classification Report")
                        # Convert dict to DataFrame for nice display
                        report_df = pd.DataFrame(results['classification_report']).transpose()
                        st.dataframe(report_df)

                    st.subheader("Feature Importances")
                    feature_importances_df = pd.DataFrame(list(results['feature_importances'].items()), columns=['Feature', 'Importance'])
                    fig_fi = px.bar(feature_importances_df, x='Importance', y='Feature', orientation='h',
                                    title='Feature Importances',
                                    labels={'Importance': 'Importance Score', 'Feature': 'Feature Name'})
                    fig_fi.update_layout(yaxis={'categoryorder': 'total ascending'})  # Sort bars
                    st.plotly_chart(fig_fi, use_container_width=True)
                else:
                    st.error(f"Error during prediction: {response.json().get('detail', 'Unknown error')}")
            except requests.exceptions.ConnectionError:
                st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
            except Exception as e:
                st.error(f"An unexpected error occurred during prediction: {e}")
else:
    st.info("Please preprocess data first to use the Prediction Module.")

# --- Time Series Causal Discovery Module ---
st.header("7. Time Series Causal Discovery ⏰")
if st.session_state.processed_data:
    st.write("Infer causal relationships in time-series data using Granger Causality.")
    st.info("Ensure your dataset includes a timestamp column and that variables are numeric.")

    all_columns = st.session_state.processed_columns

    # Heuristic to suggest potential timestamp columns (object/string type, or first column)
    potential_ts_cols = [col for col in all_columns if pd.DataFrame(st.session_state.processed_data)[col].dtype == 'object']
    if not potential_ts_cols and all_columns:  # If no object columns, suggest the first column
        potential_ts_cols = [all_columns[0]]

    timestamp_col = st.selectbox(
        "Select Timestamp Column:",
        potential_ts_cols if potential_ts_cols else ["No suitable timestamp column found. Please check data."],
        key="ts_col_select"
    )

    # Filter out timestamp column and non-numeric columns for analysis
    variables_for_ts_analysis = [
        col for col in all_columns if col != timestamp_col and pd.api.types.is_numeric_dtype(pd.DataFrame(st.session_state.processed_data)[col])
    ]

    variables_to_analyze = st.multiselect(
        "Select Variables to Analyze for Granger Causality:",
        variables_for_ts_analysis,
        default=variables_for_ts_analysis,
        key="ts_vars_select"
    )

    max_lags = st.number_input(
        "Max Lags (for Granger Causality):",
        min_value=1,
        value=5,  # Default value
        step=1,
        help="The maximum number of lagged observations to consider for causality."
    )

    if st.button("Discover Time Series Causality", key="ts_discover_button"):
        if not timestamp_col or not variables_to_analyze:
            st.warning("Please select a timestamp column and at least one variable to analyze.")
        elif "No suitable timestamp column found" in timestamp_col:
            st.error("Cannot proceed. Please ensure your data has a suitable timestamp column.")
        else:
            st.info("Performing Granger Causality tests...")
            try:
                response = requests.post(
                    f"{FLASK_API_URL}/timeseries/discover_causality",
                    json={
                        "data": st.session_state.processed_data,
                        "timestamp_col": timestamp_col,
                        "variables_to_analyze": variables_to_analyze,
                        "max_lags": max_lags
                    }
                )

                if response.status_code == 200:
                    results = response.json()['results']
                    st.success("Time Series Causal Discovery Complete!")
                    st.subheader("Granger Causality Test Results")

                    if results:
                        # Convert results to a DataFrame for better display
                        results_df = pd.DataFrame(results)
                        results_df['p_value'] = results_df['p_value'].round(4)  # Round p-values
                        st.dataframe(results_df)

                        st.markdown("**Interpretation:** A small p-value (typically < 0.05) suggests that the 'cause' variable Granger-causes the 'effect' variable. This means past values of the 'cause' variable help predict future values of the 'effect' variable, even when past values of the 'effect' variable are considered.")
                        st.markdown("*(Note: Granger Causality implies predictive causality, not necessarily true mechanistic causality. Also, ensure your time series are stationary for robust results.)*")

                        # Optionally, visualize a simple causality graph
                        st.subheader("Granger Causality Graph")
                        fig_ts_graph = go.Figure()
                        nodes = []
                        edges = []
                        edge_colors = []

                        # Add nodes
                        for i, var in enumerate(variables_to_analyze):
                            nodes.append(dict(id=var, label=var, x=np.cos(i*2*np.pi/len(variables_to_analyze)), y=np.sin(i*2*np.pi/len(variables_to_analyze))))

                        # Add edges
                        for res in results:
                            if res['p_value'] < 0.05:  # Consider it a causal link if p-value is below significance
                                edges.append(dict(source=res['cause'], target=res['effect'], value=1/res['p_value'], title=f"p={res['p_value']:.4f}"))
                                edge_colors.append("blue")
                            else:
                                # Optional: Show non-significant edges in a different color or omit
                                pass

                        # Use a simple network graph layout (Spring layout is common)
                        # For a truly interactive graph, you might need a different library or more complex Plotly setup
                        # This is a very basic attempt to visualize; consider more robust solutions like NetworkX + Plotly/Dash

                        # Simple way to draw arrows for significant relationships
                        significant_edges = [edge for edge in results if edge['p_value'] < 0.05]
                        if significant_edges:
                            st.write("Visualizing significant (p < 0.05) Granger causal links:")
                            # This needs a more robust way to draw directed edges in plotly if using just scatter/lines.
                            # For now, let's just list them clearly.
                            for edge in significant_edges:
                                st.write(f"➡️ **{edge['cause']}** Granger-causes **{edge['effect']}** (p={edge['p_value']:.4f})")
                        else:
                            st.info("No significant Granger causal links found at p < 0.05.")

                    else:
                        st.info("No Granger Causality relationships found or data insufficient.")

                else:
                    st.error(f"Error during time-series causal discovery: {response.json().get('detail', 'Unknown error')}")
            except requests.exceptions.ConnectionError:
                st.error(f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running.")
            except Exception as e:
                st.error(f"An unexpected error occurred during time-series causal discovery: {e}")
else:
    st.info("Please preprocess data first to use the Time Series Causal Discovery Module.")

# --- CausalBox Chat Assistant ---
st.header("8. CausalBox Chat Assistant 🤖")
st.write("Ask questions about your loaded dataset, causal concepts, or the discovered causal graph!")

# Initialize chat history in session state
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("Ask me anything about CausalBox..."):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Prepare session context to send to the backend
    session_context = {
        "processed_data": st.session_state.processed_data,
        "processed_columns": st.session_state.processed_columns,
        "causal_graph_adj": st.session_state.causal_graph_adj,
        "causal_graph_nodes": st.session_state.causal_graph_nodes,
        # Add any other relevant session state variables that the chatbot might need
    }

    with st.spinner("Thinking..."):
        try:
            response = requests.post(
                f"{FLASK_API_URL}/chatbot/message",
                json={
                    "user_message": prompt,
                    "session_context": session_context
                }
            )

            if response.status_code == 200:
                chatbot_response_text = response.json().get('response', 'Sorry, I could not generate a response.')
            else:
                chatbot_response_text = f"Error from chatbot backend: {response.json().get('detail', 'Unknown error')}"
        except requests.exceptions.ConnectionError:
            chatbot_response_text = f"Could not connect to Flask API at {FLASK_API_URL}. Please ensure the backend is running."
        except Exception as e:
            chatbot_response_text = f"An unexpected error occurred while getting chatbot response: {e}"

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(chatbot_response_text)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": chatbot_response_text})

# --- Future Work (Simplified) ---
st.header("Future Work 🚀")
st.markdown("""
- **🔄 Auto-causal graph refresh:** Monitor dataset updates and automatically refresh the causal graph.
""")

st.markdown("---")
st.info("Developed by CausalBox Team. For support, please contact us.")