Thiresh committed on
Commit f64df95 · verified · Parent: ef153c9

Update data_analysis_agent.py

Files changed (1):
  1. data_analysis_agent.py +595 -595
data_analysis_agent.py CHANGED
@@ -1,596 +1,596 @@
 import os, io, re
 import pandas as pd
 import numpy as np
 import streamlit as st
 from openai import OpenAI
 import matplotlib.pyplot as plt
 from typing import List, Any, Optional

 # === Configuration ===
 # Global configuration
 API_BASE_URL = "https://integrate.api.nvidia.com/v1"
-API_KEY = "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" #os.environ.get("NVIDIA_API_KEY")
+API_KEY = os.environ.get("NVIDIA_API_KEY")

 # Plot configuration
 DEFAULT_FIGSIZE = (6, 4)
 DEFAULT_DPI = 100

 # Display configuration
 MAX_RESULT_DISPLAY_LENGTH = 300

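 # With the key no longer hardcoded, it must come from the environment.
 # Minimal sketch of the assumed setup (shell, illustrative values):
 #   export NVIDIA_API_KEY="nvapi-..."
 #   streamlit run data_analysis_agent.py
 # Note: if NVIDIA_API_KEY is unset, os.environ.get returns None and the OpenAI
 # client below will typically refuse to initialize without a key.
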
 class ModelConfig:
     """Configuration class for different models."""

     def __init__(self, model_name: str, model_url: str, model_print_name: str,
                  # QueryUnderstandingTool parameters
                  query_understanding_temperature: float = 0.1,
                  query_understanding_max_tokens: int = 5,
                  # CodeGenerationAgent parameters
                  code_generation_temperature: float = 0.2,
                  code_generation_max_tokens: int = 1024,
                  # ReasoningAgent parameters
                  reasoning_temperature: float = 0.2,
                  reasoning_max_tokens: int = 1024,
                  # DataInsightAgent parameters
                  insights_temperature: float = 0.2,
                  insights_max_tokens: int = 512,
                  reasoning_false: str = "detailed thinking off",
                  reasoning_true: str = "detailed thinking on"):
         self.MODEL_NAME = model_name
         self.MODEL_URL = model_url
         self.MODEL_PRINT_NAME = model_print_name

         # Function-specific LLM parameters
         self.QUERY_UNDERSTANDING_TEMPERATURE = query_understanding_temperature
         self.QUERY_UNDERSTANDING_MAX_TOKENS = query_understanding_max_tokens
         self.CODE_GENERATION_TEMPERATURE = code_generation_temperature
         self.CODE_GENERATION_MAX_TOKENS = code_generation_max_tokens
         self.REASONING_TEMPERATURE = reasoning_temperature
         self.REASONING_MAX_TOKENS = reasoning_max_tokens
         self.INSIGHTS_TEMPERATURE = insights_temperature
         self.INSIGHTS_MAX_TOKENS = insights_max_tokens
         self.REASONING_FALSE = reasoning_false
         self.REASONING_TRUE = reasoning_true

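 # Hedged usage sketch (illustrative values, not part of this repo): the presets
 # below are built by instantiating this class, e.g.
 #   cfg = ModelConfig(model_name="vendor/model-id",
 #                     model_url="https://example.com/model",
 #                     model_print_name="My Model")
 #   cfg.CODE_GENERATION_TEMPERATURE  # -> 0.2 (default)
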
 # Predefined model configurations
 MODEL_CONFIGS = {
     "llama-3-1-nemotron-ultra-v1": ModelConfig(
         model_name="nvidia/llama-3.1-nemotron-ultra-253b-v1",
         model_url="https://build.nvidia.com/nvidia/llama-3_1-nemotron-ultra-253b-v1",
         model_print_name="NVIDIA Llama 3.1 Nemotron Ultra 253B v1",
         # QueryUnderstandingTool
         query_understanding_temperature=0.1,
         query_understanding_max_tokens=5,
         # CodeGenerationAgent
         code_generation_temperature=0.2,
         code_generation_max_tokens=1024,
         # ReasoningAgent
         reasoning_temperature=0.6,
         reasoning_max_tokens=1024,
         # DataInsightAgent
         insights_temperature=0.2,
         insights_max_tokens=512,
         reasoning_false="detailed thinking off",
         reasoning_true="detailed thinking on"
     ),
     "llama-3-3-nemotron-super-v1-5": ModelConfig(
         model_name="nvidia/llama-3.3-nemotron-super-49b-v1.5",
         model_url="https://build.nvidia.com/nvidia/llama-3_3-nemotron-super-49b-v1_5",
         model_print_name="NVIDIA Llama 3.3 Nemotron Super 49B v1.5",
         # QueryUnderstandingTool
         query_understanding_temperature=0.1,
         query_understanding_max_tokens=5,
         # CodeGenerationAgent
         code_generation_temperature=0.0,
         code_generation_max_tokens=1024,
         # ReasoningAgent
         reasoning_temperature=0.6,
         reasoning_max_tokens=2048,
         # DataInsightAgent
         insights_temperature=0.2,
         insights_max_tokens=512,
         reasoning_false="/no_think",
         reasoning_true=""
     )
 }

 # Default configuration (can be changed via environment variable or UI)
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "llama-3-1-nemotron-ultra-v1")
 Config = MODEL_CONFIGS.get(DEFAULT_MODEL, MODEL_CONFIGS["llama-3-1-nemotron-ultra-v1"])

 # Initialize OpenAI client with configuration
 client = OpenAI(
     base_url=API_BASE_URL,
     api_key=API_KEY
 )

 def get_current_config():
     """Get the current model configuration based on session state."""
     # Always return the current model from session state
     if "current_model" in st.session_state:
         return MODEL_CONFIGS[st.session_state.current_model]

     return MODEL_CONFIGS[DEFAULT_MODEL]

 # ------------------ QueryUnderstandingTool ---------------------------
 def QueryUnderstandingTool(query: str) -> bool:
     """Use the LLM to classify whether the query asks for a visualisation; returns True if it does."""
     current_config = get_current_config()

     # Prepend the instruction to the query
     full_prompt = f"""You are a query classifier. Your task is to determine if a user query is requesting a data visualization.

 IMPORTANT: Respond with ONLY 'true' or 'false' (lowercase, no quotes, no punctuation).

 Classify as 'true' ONLY if the query explicitly asks for:
 - A plot, chart, graph, visualization, or figure
 - To "show" or "display" data visually
 - To "create" or "generate" a visual representation
 - Words like: plot, chart, graph, visualize, show, display, create, generate, draw

 Classify as 'false' for:
 - Data analysis without visualization requests
 - Statistical calculations, aggregations, filtering, sorting
 - Questions about data content, counts, summaries
 - Requests for tables, dataframes, or text results

 User query: {query}"""

     messages = [
         {"role": "system", "content": current_config.REASONING_FALSE},
         {"role": "user", "content": full_prompt}
     ]

     response = client.chat.completions.create(
         model=current_config.MODEL_NAME,
         messages=messages,
         temperature=current_config.QUERY_UNDERSTANDING_TEMPERATURE,
         max_tokens=current_config.QUERY_UNDERSTANDING_MAX_TOKENS  # We only need a short response
     )

     # Extract the response and convert to boolean
     intent_response = response.choices[0].message.content.strip().lower()
     return intent_response == "true"

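 # Illustrative behaviour (assumes a configured API key; outputs depend on the model):
 #   QueryUnderstandingTool("plot average sales by region")        # -> True
 #   QueryUnderstandingTool("how many rows have missing values?")  # -> False
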
 # === CodeGeneration TOOLS ============================================


 # ------------------ CodeWritingTool ---------------------------------
 def CodeWritingTool(cols: List[str], query: str) -> str:
     """Generate a prompt for the LLM to write pandas-only code for a data query (no plotting)."""

     return f"""
 Given DataFrame `df` with columns:

 {', '.join(cols)}

 Write Python code (pandas **only**, no plotting) to answer:
 "{query}"

 Rules
 -----
 1. Use pandas operations on `df` only.
 2. Rely only on the columns in the DataFrame.
 3. Assign the final result to `result`.
 4. Return your answer inside a single markdown fence that starts with ```python and ends with ```.
 5. Do not include any explanations, comments, or prose outside the code block.
 6. Use **df** as the sole data source. **Do not** read files, fetch data, or use Streamlit.
 7. Do **not** import any libraries (pandas is already imported as pd).
 8. Handle missing values (`dropna`) before aggregations.

 Example
 -----
 ```python
 result = df.groupby("some_column")["a_numeric_col"].mean().sort_values(ascending=False)
 ```
 """

 # ------------------ PlotCodeGeneratorTool ---------------------------
 def PlotCodeGeneratorTool(cols: List[str], query: str) -> str:
     """Generate a prompt for the LLM to write pandas + matplotlib code for a plot based on the query and columns."""

     return f"""
 Given DataFrame `df` with columns:

 {', '.join(cols)}

 Write Python code using pandas **and matplotlib** (as plt) to answer:
 "{query}"

 Rules
 -----
 1. Use pandas for data manipulation and matplotlib.pyplot (as plt) for plotting.
 2. Rely only on the columns in the DataFrame.
 3. Assign the final result (DataFrame, Series, scalar *or* matplotlib Figure) to a variable named `result`.
 4. Create only ONE relevant plot. Set `figsize={DEFAULT_FIGSIZE}`, add title/labels.
 5. Return your answer inside a single markdown fence that starts with ```python and ends with ```.
 6. Do not include any explanations, comments, or prose outside the code block.
 7. Handle missing values (`dropna`) before plotting/aggregations.
 """

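 # Hedged example of a reply that satisfies the rules above (column names are
 # illustrative, not from any real dataset):
 #   ```python
 #   counts = df["category"].dropna().value_counts()
 #   fig, ax = plt.subplots(figsize=(6, 4))
 #   counts.plot.bar(ax=ax)
 #   ax.set_title("Rows per category")
 #   result = fig
 #   ```
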
 # === CodeGenerationAgent ==============================================

 def CodeGenerationAgent(query: str, df: pd.DataFrame, chat_context: Optional[str] = None):
     """Selects the appropriate code generation tool and gets code from the LLM for the user's query."""
     should_plot = QueryUnderstandingTool(query)
     prompt = PlotCodeGeneratorTool(df.columns.tolist(), query) if should_plot else CodeWritingTool(df.columns.tolist(), query)

     # Prepend the instruction to the query
     context_section = f"\nConversation context (recent user turns):\n{chat_context}\n" if chat_context else ""
     full_prompt = f"""You are a senior Python data analyst who writes clean, efficient code.
 Solve the given problem with optimal pandas operations. Be concise and focused.
 Your response must contain ONLY a properly-closed ```python code block with no explanations before or after (starts with ```python and ends with ```).
 Ensure your solution is correct, handles edge cases, and follows best practices for data analysis.
 If the latest user request references prior results ambiguously (e.g., "it", "that", "same groups"), infer intent from the conversation context and choose the most reasonable interpretation. {context_section}{prompt}"""

     current_config = get_current_config()
     messages = [
         {"role": "system", "content": current_config.REASONING_FALSE},
         {"role": "user", "content": full_prompt}
     ]

     response = client.chat.completions.create(
         model=current_config.MODEL_NAME,
         messages=messages,
         temperature=current_config.CODE_GENERATION_TEMPERATURE,
         max_tokens=current_config.CODE_GENERATION_MAX_TOKENS
     )

     full_response = response.choices[0].message.content
     code = extract_first_code_block(full_response)
     return code, should_plot, ""

 # === ExecutionAgent ====================================================

 def ExecutionAgent(code: str, df: pd.DataFrame, should_plot: bool):
     """Executes the generated code in a controlled environment and returns the result or error message."""
     # Set up the execution environment with the modules the generated code may use
     env = {"pd": pd, "df": df}
     if should_plot:
         plt.rcParams["figure.dpi"] = DEFAULT_DPI  # Set default DPI for all figures
         env["plt"] = plt
         env["io"] = io

     try:
         # Execute the code with `env` as its namespace (a single dict, so names
         # defined at top level remain visible inside generated helper functions)
         exec(code, env)
         if "result" not in env:
             return "No result was assigned to 'result' variable"
         return env["result"]
     except Exception as exc:
         return f"Error executing code: {exc}"

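 # Hedged usage sketch (in-memory frame, no LLM involved):
 #   demo = pd.DataFrame({"a": [1, 2, None]})
 #   ExecutionAgent("result = df['a'].dropna().sum()", demo, should_plot=False)  # -> 3.0
 #   ExecutionAgent("oops(", demo, should_plot=False)  # -> "Error executing code: ..."
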
 # === ReasoningCurator TOOL =========================================
 def ReasoningCurator(query: str, result: Any) -> str:
     """Builds and returns the LLM prompt for reasoning about the result."""
     is_error = isinstance(result, str) and result.startswith("Error executing code")
     is_plot = isinstance(result, (plt.Figure, plt.Axes))

     if is_error:
         desc = result
     elif is_plot:
         title = ""
         if isinstance(result, plt.Figure):
             title = result._suptitle.get_text() if result._suptitle else ""
         elif isinstance(result, plt.Axes):
             title = result.get_title()
         desc = f"[Plot Object: {title or 'Chart'}]"
     else:
         desc = str(result)[:MAX_RESULT_DISPLAY_LENGTH]

     if is_plot:
         prompt = f'''
 The user asked: "{query}".
 Below is a description of the plot result:
 {desc}
 Explain in 2–3 concise sentences what the chart shows (no code talk).'''
     else:
         prompt = f'''
 The user asked: "{query}".
 The result value is: {desc}
 Explain in 2–3 concise sentences what this tells about the data (no mention of charts).'''
     return prompt

 # === ReasoningAgent (streaming) =========================================
 def ReasoningAgent(query: str, result: Any):
     """Streams the LLM's reasoning about the result (plot or value) and extracts model 'thinking' and final explanation."""
     current_config = get_current_config()
     prompt = ReasoningCurator(query, result)

     # Streaming LLM call
     response = client.chat.completions.create(
         model=current_config.MODEL_NAME,
         messages=[
             {"role": "system", "content": current_config.REASONING_TRUE},
             {"role": "user", "content": "You are an insightful data analyst. " + prompt}
         ],
         temperature=current_config.REASONING_TEMPERATURE,
         max_tokens=current_config.REASONING_MAX_TOKENS,
         stream=True
     )

     # Stream and display thinking
     thinking_placeholder = st.empty()
     full_response = ""
     thinking_content = ""
     in_think = False

     for chunk in response:
         if chunk.choices[0].delta.content is not None:
             token = chunk.choices[0].delta.content
             full_response += token

             # Simple state machine to extract <think>...</think> as it streams
             if "<think>" in token:
                 in_think = True
                 token = token.split("<think>", 1)[1]
             if "</think>" in token:
                 token = token.split("</think>", 1)[0]
                 in_think = False
             if in_think or ("<think>" in full_response and "</think>" not in full_response):
                 thinking_content += token
                 thinking_placeholder.markdown(
                     f'<details class="thinking" open><summary>🤔 Model Thinking</summary><pre>{thinking_content}</pre></details>',
                     unsafe_allow_html=True
                 )

     # After streaming, extract final reasoning (outside <think>...</think>)
     cleaned = re.sub(r"<think>.*?</think>", "", full_response, flags=re.DOTALL).strip()
     return thinking_content, cleaned

 # === DataFrameSummary TOOL (pandas only) =========================================
 def DataFrameSummaryTool(df: pd.DataFrame) -> str:
     """Generate a summary prompt string for the LLM based on the DataFrame."""
     prompt = f"""
 Given a dataset with {len(df)} rows and {len(df.columns)} columns:
 Columns: {', '.join(df.columns)}
 Data types: {df.dtypes.to_dict()}
 Missing values: {df.isnull().sum().to_dict()}

 Provide:
 1. A brief description of what this dataset contains
 2. 3-4 possible data analysis questions that could be explored
 Keep it concise and focused."""
     return prompt

 # === DataInsightAgent (upload-time only) ===============================

 def DataInsightAgent(df: pd.DataFrame) -> str:
     """Uses the LLM to generate a brief summary and possible questions for the uploaded dataset."""
     current_config = get_current_config()
     prompt = DataFrameSummaryTool(df)
     try:
         response = client.chat.completions.create(
             model=current_config.MODEL_NAME,
             messages=[
                 {"role": "system", "content": current_config.REASONING_FALSE},
                 {"role": "user", "content": "You are a data analyst providing brief, focused insights. " + prompt}
             ],
             temperature=current_config.INSIGHTS_TEMPERATURE,
             max_tokens=current_config.INSIGHTS_MAX_TOKENS
         )
         return response.choices[0].message.content
     except Exception as exc:
         raise RuntimeError(f"Error generating dataset insights: {exc}") from exc

 # === Helpers ===========================================================

 def extract_first_code_block(text: str) -> str:
     """Extracts the first Python code block from a markdown-formatted string."""
     start = text.find("```python")
     if start == -1:
         return ""
     start += len("```python")
     end = text.find("```", start)
     if end == -1:
         return ""
     return text[start:end].strip()

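 # Hedged examples:
 #   extract_first_code_block("x\n```python\nresult = 1\n```\n")  # -> "result = 1"
 #   extract_first_code_block("no fence here")                    # -> ""
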
 # === Main Streamlit App ===============================================

 def main():
     st.set_page_config(layout="wide")
     if "plots" not in st.session_state:
         st.session_state.plots = []
     if "current_model" not in st.session_state:
         st.session_state.current_model = DEFAULT_MODEL

     # Page logo at top right corner, large and clickable
     st.markdown(
         """
         <div style='position: absolute; top: 20px; right: 30px; z-index: 999;'>
             <a href='https://www.linkedin.com/in/thiresh-sidda/' target='_blank'>
                 <img src='https://ih1.redbubble.net/image.1849728168.3104/raf,360x360,075,t,fafafa:ca443f4786.jpg' alt='Logo' style='height:120px; border-radius:20px; box-shadow:0 2px 12px rgba(0,0,0,0.15);'>
             </a>
         </div>
         """,
         unsafe_allow_html=True
     )
     # Main title centered with large font and GIF
     st.markdown(
         """
         <div style='display: flex; align-items: center; justify-content: center; margin-bottom: 30px;'>
             <span style='color:#1976D2; font-weight:bold; font-size:3.5em; margin-right:30px;'>Data Analysis Agent</span>
             <img src='https://cdn.dribbble.com/userupload/23161671/file/original-4c7894556285d8f223ab21fd10554fe4.gif' alt='GIF' style='height:120px;'>
         </div>
         """,
         unsafe_allow_html=True
     )

     medium_blue = "#1976D2"  # Medium blue color

     # Move left panel to sidebar
     with st.sidebar:
         st.markdown(f"<span style='color:{medium_blue}; font-weight:bold; font-size:1.5em;'>Insights Generator</span>", unsafe_allow_html=True)
         available_models = list(MODEL_CONFIGS.keys())
         model_display_names = {key: MODEL_CONFIGS[key].MODEL_PRINT_NAME for key in available_models}
         selected_model = st.selectbox(
             "Select Model",
             options=available_models,
             format_func=lambda x: model_display_names[x],
             index=available_models.index(st.session_state.current_model)
         )
         display_config = MODEL_CONFIGS[selected_model]
         file = st.file_uploader("Choose CSV", type=["csv"], key="csv_uploader")
         # Update configuration if model changed
         if selected_model != st.session_state.current_model:
             st.session_state.current_model = selected_model
             new_config = MODEL_CONFIGS[selected_model]
             if "messages" in st.session_state:
                 st.session_state.messages = []
             if "plots" in st.session_state:
                 st.session_state.plots = []
             if "df" in st.session_state and file is not None:
                 with st.spinner("Generating dataset insights with new model …"):
                     try:
                         st.session_state.insights = DataInsightAgent(st.session_state.df)
                         st.success(f"Insights updated with {new_config.MODEL_PRINT_NAME}")
                     except Exception as e:
                         st.error(f"Error updating insights: {str(e)}")
                         if "insights" in st.session_state:
                             del st.session_state.insights
             st.rerun()
         if not file and "df" in st.session_state and "current_file" in st.session_state:
             del st.session_state.df
             del st.session_state.current_file
             if "insights" in st.session_state:
                 del st.session_state.insights
             st.rerun()
         if file:
             if ("df" not in st.session_state) or (st.session_state.get("current_file") != file.name):
                 st.session_state.df = pd.read_csv(file)
                 st.session_state.current_file = file.name
                 st.session_state.messages = []
                 with st.spinner("Generating dataset insights …"):
                     try:
                         st.session_state.insights = DataInsightAgent(st.session_state.df)
                     except Exception as e:
                         st.error(f"Error generating insights: {str(e)}")
             elif "insights" not in st.session_state:
                 with st.spinner("Generating dataset insights …"):
                     try:
                         st.session_state.insights = DataInsightAgent(st.session_state.df)
                     except Exception as e:
                         st.error(f"Error generating insights: {str(e)}")
         if "df" in st.session_state:
             st.markdown(f"<span style='color:{medium_blue}; font-weight:bold; font-size:1.2em;'>Your Dataset Insights</span>", unsafe_allow_html=True)
             if "insights" in st.session_state and st.session_state.insights:
                 st.dataframe(st.session_state.df.head())
                 st.markdown(f"<span style='color:{medium_blue};'>{st.session_state.insights}</span>", unsafe_allow_html=True)
                 current_config_left = get_current_config()
                 # st.markdown(f"*<span style='color: grey; font-style: italic;'>Generated with {current_config_left.MODEL_PRINT_NAME}</span>*", unsafe_allow_html=True)
             else:
                 st.warning("No insights available.")
         else:
             st.info("Upload a CSV to begin chatting with your data.")

     with st.container():
         st.markdown(
             f"""
             <div style='display: flex; align-items: center; justify-content: flex-start; margin-bottom: 10px;'>
                 <span style='color:{medium_blue}; font-weight:bold; font-size:2em; margin-right:20px;'>Chat with your data</span>
                 <img src='https://i.pinimg.com/originals/5f/d5/58/5fd558f8b7a4f9e2138709cbe63c7052.gif' alt='Chat GIF' style='height:48px;'>
             </div>
             """,
             unsafe_allow_html=True
         )
         if "df" in st.session_state:
             current_config_right = get_current_config()
             st.markdown(f"*<span style='color: grey; font-style: italic;'>Using {current_config_right.MODEL_PRINT_NAME}</span>*", unsafe_allow_html=True)
         if "messages" not in st.session_state:
             st.session_state.messages = []

         clear_col1, clear_col2 = st.columns([9, 1])
         with clear_col2:
             if st.button("Clear chat"):
                 st.session_state.messages = []
                 st.session_state.plots = []
                 st.rerun()

         for msg in st.session_state.messages:
             with st.chat_message(msg["role"]):
                 st.markdown(f"<span style='color:{medium_blue}; font-size:1.1em;'>{msg['content']}</span>", unsafe_allow_html=True)
                 if msg.get("plot_index") is not None:
                     idx = msg["plot_index"]
                     if 0 <= idx < len(st.session_state.plots):
                         st.pyplot(st.session_state.plots[idx], use_container_width=False)

         if "df" in st.session_state:
             if user_q := st.chat_input("Ask about your data…"):
                 st.session_state.messages.append({"role": "user", "content": user_q})
                 with st.spinner("Working …"):
                     recent_user_turns = [m["content"] for m in st.session_state.messages if m["role"] == "user"][-3:]
                     context_text = "\n".join(recent_user_turns[:-1]) if len(recent_user_turns) > 1 else None
                     code, should_plot_flag, code_thinking = CodeGenerationAgent(user_q, st.session_state.df, context_text)
                     result_obj = ExecutionAgent(code, st.session_state.df, should_plot_flag)
                     raw_thinking, reasoning_txt = ReasoningAgent(user_q, result_obj)
                     reasoning_txt = reasoning_txt.replace("`", "")

                     is_plot = isinstance(result_obj, (plt.Figure, plt.Axes))
                     plot_idx = None
                     if is_plot:
                         fig = result_obj.figure if isinstance(result_obj, plt.Axes) else result_obj
                         st.session_state.plots.append(fig)
                         plot_idx = len(st.session_state.plots) - 1
                         header = "Here is the visualization you requested:"
                     elif isinstance(result_obj, (pd.DataFrame, pd.Series)):
                         header = f"Result: {len(result_obj)} rows" if isinstance(result_obj, pd.DataFrame) else "Result series"
                     else:
                         header = f"Result: {result_obj}"

                     thinking_html = ""
                     if raw_thinking:
                         thinking_html = (
                             '<details class="thinking">'
                             '<summary>🧠 Reasoning</summary>'
                             f'<pre>{raw_thinking}</pre>'
                             '</details>'
                         )

                     explanation_html = reasoning_txt

                     code_html = (
                         '<details class="code">'
                         '<summary>View code</summary>'
                         '<pre><code class="language-python">'
                         f'{code}'
                         '</code></pre>'
                         '</details>'
                     )
                     assistant_msg = f"{thinking_html}{explanation_html}\n\n{code_html}"

                     st.session_state.messages.append({
                         "role": "assistant",
                         "content": assistant_msg,
                         "plot_index": plot_idx
                     })
                 st.rerun()

 if __name__ == "__main__":
     main()