Thiresh committed on
Commit
ba4339f
·
verified ·
1 Parent(s): 11e04a9

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +25 -20
  2. data_analysis_agent.py +596 -0
  3. requirements.txt +6 -3
Dockerfile CHANGED
@@ -1,20 +1,25 @@
1
- FROM python:3.13.5-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- git \
9
- && rm -rf /var/lib/apt/lists/*
10
-
11
- COPY requirements.txt ./
12
- COPY src/ ./src/
13
-
14
- RUN pip3 install -r requirements.txt
15
-
16
- EXPOSE 8501
17
-
18
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
-
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
 
 
 
 
1
+ # Use a minimal Python base image
2
+ FROM python:3.9-slim
3
+
4
+ # Set working directory inside the container
5
+ WORKDIR /app
6
+
7
+ # Copy requirements first for caching
8
+ COPY requirements.txt .
9
+
10
+ # Install dependencies
11
+ RUN pip install --upgrade pip && \
12
+ pip install -r requirements.txt
13
+
14
+ # Copy all files into the container
15
+ COPY . .
16
+
17
+ # Expose Streamlit's default port
18
+ EXPOSE 8501
19
+
20
+ # Run the Streamlit app
21
+ # NVIDIA_API_KEY is not passed on this command line; the app has to read it from the container environment (e.g. Hugging Face Space secrets)
22
+ CMD ["streamlit", "run", "data_analysis_agent.py", \
23
+ "--server.port=8501", \
24
+ "--server.address=0.0.0.0", \
25
+ "--server.enableXsrfProtection=false"]
data_analysis_agent.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, re
2
+ import pandas as pd
3
+ import numpy as np
4
+ import streamlit as st
5
+ from openai import OpenAI
6
+ import matplotlib.pyplot as plt
7
+ from typing import List, Any, Optional
8
+
9
+ # === Configuration ===
10
+ # Global configuration
11
API_BASE_URL = "https://integrate.api.nvidia.com/v1"
# SECURITY: the API key is read from the environment (for example a Hugging
# Face Space secret named NVIDIA_API_KEY) and must never be committed to the
# repository. Any key that has ever appeared in version control should be
# treated as leaked and revoked.
API_KEY = os.environ.get("NVIDIA_API_KEY")

# Plot configuration
DEFAULT_FIGSIZE = (6, 4)  # inches, interpolated into the plotting prompt
DEFAULT_DPI = 100  # applied to matplotlib's rcParams before plot code runs

# Display configuration
MAX_RESULT_DISPLAY_LENGTH = 300  # max characters of a result shown to the LLM
20
+
21
class ModelConfig:
    """Per-model settings: identity, sampling parameters and system prompts.

    Every constructor argument is stored unchanged on an upper-case attribute
    of the same name (``model_name`` -> ``MODEL_NAME`` and so on).

    Args:
        model_name: Model identifier sent to the chat-completions API.
        model_url: Web page describing the model.
        model_print_name: Human-readable name shown in the UI.
        query_understanding_temperature: Temperature for QueryUnderstandingTool.
        query_understanding_max_tokens: Token cap for QueryUnderstandingTool.
        code_generation_temperature: Temperature for CodeGenerationAgent.
        code_generation_max_tokens: Token cap for CodeGenerationAgent.
        reasoning_temperature: Temperature for ReasoningAgent.
        reasoning_max_tokens: Token cap for ReasoningAgent.
        insights_temperature: Temperature for DataInsightAgent.
        insights_max_tokens: Token cap for DataInsightAgent.
        reasoning_false: System prompt used when reasoning should be off.
        reasoning_true: System prompt used when reasoning should be on.
    """

    def __init__(
        self,
        model_name: str,
        model_url: str,
        model_print_name: str,
        query_understanding_temperature: float = 0.1,
        query_understanding_max_tokens: int = 5,
        code_generation_temperature: float = 0.2,
        code_generation_max_tokens: int = 1024,
        reasoning_temperature: float = 0.2,
        reasoning_max_tokens: int = 1024,
        insights_temperature: float = 0.2,
        insights_max_tokens: int = 512,
        reasoning_false: str = "detailed thinking off",
        reasoning_true: str = "detailed thinking on",
    ):
        # Model identity.
        self.MODEL_NAME, self.MODEL_URL, self.MODEL_PRINT_NAME = (
            model_name,
            model_url,
            model_print_name,
        )

        # Sampling parameters, one (temperature, max_tokens) pair per agent.
        self.QUERY_UNDERSTANDING_TEMPERATURE, self.QUERY_UNDERSTANDING_MAX_TOKENS = (
            query_understanding_temperature,
            query_understanding_max_tokens,
        )
        self.CODE_GENERATION_TEMPERATURE, self.CODE_GENERATION_MAX_TOKENS = (
            code_generation_temperature,
            code_generation_max_tokens,
        )
        self.REASONING_TEMPERATURE, self.REASONING_MAX_TOKENS = (
            reasoning_temperature,
            reasoning_max_tokens,
        )
        self.INSIGHTS_TEMPERATURE, self.INSIGHTS_MAX_TOKENS = (
            insights_temperature,
            insights_max_tokens,
        )

        # System prompts that switch the model's reasoning mode.
        self.REASONING_FALSE, self.REASONING_TRUE = reasoning_false, reasoning_true
54
+
55
+ # Predefined model configurations
56
# Registry of selectable models, keyed by the identifier stored in
# st.session_state.current_model and accepted by the DEFAULT_MODEL env var.
MODEL_CONFIGS = {
    "llama-3-1-nemotron-ultra-v1": ModelConfig(
        model_name="nvidia/llama-3.1-nemotron-ultra-253b-v1",
        model_url="https://build.nvidia.com/nvidia/llama-3_1-nemotron-ultra-253b-v1",
        model_print_name="NVIDIA Llama 3.1 Nemotron Ultra 253B v1",
        # Sampling temperatures, one per agent.
        query_understanding_temperature=0.1,
        code_generation_temperature=0.2,
        reasoning_temperature=0.6,
        insights_temperature=0.2,
        # Completion length caps, one per agent.
        query_understanding_max_tokens=5,
        code_generation_max_tokens=1024,
        reasoning_max_tokens=1024,
        insights_max_tokens=512,
        # System prompts toggling the reasoning mode.
        reasoning_false="detailed thinking off",
        reasoning_true="detailed thinking on",
    ),
    "llama-3-3-nemotron-super-v1-5": ModelConfig(
        model_name="nvidia/llama-3.3-nemotron-super-49b-v1.5",
        model_url="https://build.nvidia.com/nvidia/llama-3_3-nemotron-super-49b-v1_5",
        model_print_name="NVIDIA Llama 3.3 Nemotron Super 49B v1.5",
        # Sampling temperatures, one per agent.
        query_understanding_temperature=0.1,
        code_generation_temperature=0.0,
        reasoning_temperature=0.6,
        insights_temperature=0.2,
        # Completion length caps, one per agent.
        query_understanding_max_tokens=5,
        code_generation_max_tokens=1024,
        reasoning_max_tokens=2048,
        insights_max_tokens=512,
        # System prompts toggling the reasoning mode.
        reasoning_false="/no_think",
        reasoning_true="",
    ),
}
96
+
97
+ # Default configuration (can be changed via environment variable or UI)
98
# Default configuration (can be changed via environment variable or UI).
# The environment value is validated here so that every later use of
# DEFAULT_MODEL as a MODEL_CONFIGS key is safe: an unknown value used to be
# tolerated only by ``Config`` while direct ``MODEL_CONFIGS[DEFAULT_MODEL]``
# lookups elsewhere raised KeyError.
_FALLBACK_MODEL = "llama-3-1-nemotron-ultra-v1"
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", _FALLBACK_MODEL)
if DEFAULT_MODEL not in MODEL_CONFIGS:
    DEFAULT_MODEL = _FALLBACK_MODEL
Config = MODEL_CONFIGS[DEFAULT_MODEL]

# OpenAI-compatible client pointed at the configured endpoint. It is shared
# by every agent in this module.
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=API_KEY,
)
106
+
107
def get_current_config():
    """Return the ModelConfig for the model selected in the session.

    The key is read from ``st.session_state.current_model`` and defaults to
    ``DEFAULT_MODEL`` when the session has not chosen a model yet. A key that
    is not present in ``MODEL_CONFIGS`` (stale session value or an invalid
    ``DEFAULT_MODEL`` environment variable) falls back to the module-level
    ``Config`` instead of raising ``KeyError``.

    Returns:
        The ``ModelConfig`` instance to use for the current request.
    """
    fallback = MODEL_CONFIGS.get(DEFAULT_MODEL, Config)
    selected_key = st.session_state.get("current_model", DEFAULT_MODEL)
    return MODEL_CONFIGS.get(selected_key, fallback)
114
+
115
+ # ------------------ QueryUnderstandingTool ---------------------------
116
def QueryUnderstandingTool(query: str) -> bool:
    """Ask the LLM whether *query* requests a data visualisation.

    The query is embedded in a classification prompt and sent to the
    currently selected model with reasoning disabled. The reply is
    normalised and compared with the literal ``true``.

    Args:
        query: The user's natural-language request.

    Returns:
        True when the model answers ``true``; False for any other reply,
        including an empty one.
    """
    current_config = get_current_config()

    # Prepend the instruction to the query
    full_prompt = f"""You are a query classifier. Your task is to determine if a user query is requesting a data visualization.

IMPORTANT: Respond with ONLY 'true' or 'false' (lowercase, no quotes, no punctuation).

Classify as 'true' ONLY if the query explicitly asks for:
- A plot, chart, graph, visualization, or figure
- To "show" or "display" data visually
- To "create" or "generate" a visual representation
- Words like: plot, chart, graph, visualize, show, display, create, generate, draw

Classify as 'false' for:
- Data analysis without visualization requests
- Statistical calculations, aggregations, filtering, sorting
- Questions about data content, counts, summaries
- Requests for tables, dataframes, or text results

User query: {query}"""

    messages = [
        {"role": "system", "content": current_config.REASONING_FALSE},
        {"role": "user", "content": full_prompt}
    ]

    response = client.chat.completions.create(
        model=current_config.MODEL_NAME,
        messages=messages,
        temperature=current_config.QUERY_UNDERSTANDING_TEMPERATURE,
        max_tokens=current_config.QUERY_UNDERSTANDING_MAX_TOKENS  # We only need a short response
    )

    # The API may return no content at all; treat that as "not a plot".
    raw_answer = response.choices[0].message.content or ""
    # NOTE(review): some reasoning models presumably still emit an (empty)
    # <think> block with reasoning switched off - drop it if present.
    raw_answer = re.sub(r"<think>.*?(?:</think>|\Z)", "", raw_answer, flags=re.DOTALL)
    # Tolerate quotes, backticks and trailing punctuation around the verdict.
    verdict = raw_answer.strip().lower().strip("'\"`.,!;: \t\r\n")
    return verdict == "true"
157
+
158
+ # === CodeGeneration TOOLS ============================================
159
+
160
+
161
+ # ------------------ CodeWritingTool ---------------------------------
162
def CodeWritingTool(cols: List[str], query: str) -> str:
    """Build the prompt asking the LLM for pandas-only code (no plotting).

    Args:
        cols: Column labels of the DataFrame the generated code will run on.
            Non-string labels (for example integers) are converted with
            ``str`` so that joining them cannot raise ``TypeError``.
        query: The user's question, quoted verbatim in the prompt.

    Returns:
        The prompt text, including the rules the generated code must follow
        and a short example.
    """
    column_list = ', '.join(map(str, cols))

    return f"""

Given DataFrame `df` with columns:

{column_list}

Write Python code (pandas **only**, no plotting) to answer:
"{query}"

Rules
-----
1. Use pandas operations on `df` only.
2. Rely only on the columns in the DataFrame.
3. Assign the final result to `result`.
4. Return your answer inside a single markdown fence that starts with ```python and ends with ```.
5. Do not include any explanations, comments, or prose outside the code block.
6. Use **df** as the sole data source. **Do not** read files, fetch data, or use Streamlit.
7. Do **not** import any libraries (pandas is already imported as pd).
8. Handle missing values (`dropna`) before aggregations.

Example
-----
```python
result = df.groupby("some_column")["a_numeric_col"].mean().sort_values(ascending=False)
```

"""
192
+
193
+
194
+ # ------------------ PlotCodeGeneratorTool ---------------------------
195
def PlotCodeGeneratorTool(cols: List[str], query: str) -> str:
    """Build the prompt asking the LLM for pandas + matplotlib plotting code.

    Args:
        cols: Column labels of the DataFrame the generated code will run on.
            Non-string labels (for example integers) are converted with
            ``str`` so that joining them cannot raise ``TypeError``.
        query: The user's question, quoted verbatim in the prompt.

    Returns:
        The prompt text. The figure size from ``DEFAULT_FIGSIZE`` is
        interpolated into rule 4.
    """
    column_list = ', '.join(map(str, cols))

    return f"""

Given DataFrame `df` with columns:

{column_list}

Write Python code using pandas **and matplotlib** (as plt) to answer:
"{query}"

Rules
-----
1. Use pandas for data manipulation and matplotlib.pyplot (as plt) for plotting.
2. Rely only on the columns in the DataFrame.
3. Assign the final result (DataFrame, Series, scalar *or* matplotlib Figure) to a variable named `result`.
4. Create only ONE relevant plot. Set `figsize={DEFAULT_FIGSIZE}`, add title/labels.
5. Return your answer inside a single markdown fence that starts with ```python and ends with ```.
6. Do not include any explanations, comments, or prose outside the code block.
7. Handle missing values (`dropna`) before plotting/aggregations.

"""
219
+
220
+
221
+ # === CodeGenerationAgent ==============================================
222
+
223
def CodeGenerationAgent(query: str, df: pd.DataFrame, chat_context: Optional[str] = None):
    """Pick the right prompt builder and ask the LLM for code answering *query*.

    ``QueryUnderstandingTool`` decides whether a plot is wanted; the matching
    prompt (``PlotCodeGeneratorTool`` or ``CodeWritingTool``) is wrapped in
    general instructions and sent to the currently selected model.

    Args:
        query: The user's latest request.
        df: The DataFrame whose column labels are listed in the prompt.
        chat_context: Optional text with earlier user turns, inserted into
            the prompt so ambiguous references can be resolved.

    Returns:
        A ``(code, should_plot, thinking)`` tuple: the first python code
        block of the reply (``""`` when none is found or the reply is empty),
        the plot flag, and an empty string kept for interface compatibility.
    """
    should_plot = QueryUnderstandingTool(query)

    columns = df.columns.tolist()
    prompt = PlotCodeGeneratorTool(columns, query) if should_plot else CodeWritingTool(columns, query)

    # Prepend the instruction to the query
    context_section = f"\nConversation context (recent user turns):\n{chat_context}\n" if chat_context else ""

    full_prompt = f"""You are a senior Python data analyst who writes clean, efficient code.
Solve the given problem with optimal pandas operations. Be concise and focused.
Your response must contain ONLY a properly-closed ```python code block with no explanations before or after (starts with ```python and ends with ```).
Ensure your solution is correct, handles edge cases, and follows best practices for data analysis.
If the latest user request references prior results ambiguously (e.g., "it", "that", "same groups"), infer intent from the conversation context and choose the most reasonable interpretation. {context_section}{prompt}"""

    current_config = get_current_config()

    messages = [
        {"role": "system", "content": current_config.REASONING_FALSE},
        {"role": "user", "content": full_prompt}
    ]

    response = client.chat.completions.create(
        model=current_config.MODEL_NAME,
        messages=messages,
        temperature=current_config.CODE_GENERATION_TEMPERATURE,
        max_tokens=current_config.CODE_GENERATION_MAX_TOKENS
    )

    # The message content can be None; normalise to "" so the extraction
    # helper always receives a string.
    full_response = response.choices[0].message.content or ""

    code = extract_first_code_block(full_response)
    return code, should_plot, ""
257
+
258
+ # === ExecutionAgent ====================================================
259
+
260
def ExecutionAgent(code: str, df: pd.DataFrame, should_plot: bool):
    """Run generated code against *df* and return its ``result`` variable.

    SECURITY: ``exec`` runs model-generated code with full interpreter
    privileges. The namespace below limits which names are pre-bound, but it
    is NOT a sandbox; deploy only where arbitrary code execution is
    acceptable.

    Args:
        code: Python source to execute.
        df: DataFrame exposed to the code as ``df``.
        should_plot: When True, ``plt`` and ``io`` are exposed as well and
            the default figure DPI is set to ``DEFAULT_DPI``.

    Returns:
        The value the code assigned to ``result`` (which may be ``None``),
        the message ``"No result was assigned to 'result' variable"`` when
        the code never assigned it, or ``"Error executing code: ..."`` when
        the code raised an exception.
    """
    # Set up execution environment with all necessary modules
    env = {
        "pd": pd,
        "df": df
    }

    if should_plot:
        plt.rcParams["figure.dpi"] = DEFAULT_DPI  # Set default DPI for all figures
        env["plt"] = plt
        env["io"] = io

    try:
        # One dict serves as both globals and locals. With separate dicts the
        # code runs like a class body: functions, lambdas and (before Python
        # 3.12) comprehensions defined by the generated code could not see
        # `df` or `pd` and failed with NameError.
        exec(code, env)
    except Exception as exc:
        return f"Error executing code: {exc}"

    if "result" not in env:
        return "No result was assigned to 'result' variable"
    return env["result"]
288
+
289
+ # === ReasoningCurator TOOL =========================================
290
def ReasoningCurator(query: str, result: Any) -> str:
    """Compose the prompt asking the LLM to explain *result* for *query*.

    Execution-error strings are passed through verbatim, plots are replaced
    by a short placeholder carrying their title, and any other value is
    stringified and truncated to ``MAX_RESULT_DISPLAY_LENGTH`` characters.
    """
    plot_like = isinstance(result, (plt.Figure, plt.Axes))
    failed = isinstance(result, str) and result.startswith("Error executing code")

    if failed:
        desc = result
    elif plot_like:
        if isinstance(result, plt.Figure):
            # A Figure only has a title when a suptitle was set.
            suptitle = result._suptitle
            title = suptitle.get_text() if suptitle else ""
        else:
            title = result.get_title()
        desc = f"[Plot Object: {title or 'Chart'}]"
    else:
        desc = str(result)[:MAX_RESULT_DISPLAY_LENGTH]

    if not plot_like:
        return f'''
The user asked: "{query}".
The result value is: {desc}
Explain in 2–3 concise sentences what this tells about the data (no mention of charts).'''

    return f'''
The user asked: "{query}".
Below is a description of the plot result:
{desc}
Explain in 2–3 concise sentences what the chart shows (no code talk).'''
319
+
320
+ # === ReasoningAgent (streaming) =========================================
321
# Matches a <think> block; an unclosed block extends to the end of the text.
_THINK_BLOCK_RE = re.compile(r"<think>(.*?)(?:</think>|\Z)", re.DOTALL)


def _split_thinking(text: str):
    """Split raw model output into ``(thinking, answer)`` around <think> tags."""
    thinking = "".join(match.group(1) for match in _THINK_BLOCK_RE.finditer(text))
    answer = _THINK_BLOCK_RE.sub("", text).strip()
    return thinking, answer


def ReasoningAgent(query: str, result: Any):
    """Stream the LLM's explanation of *result* and separate its thinking.

    The prompt from ``ReasoningCurator`` is sent with reasoning enabled and
    the reply is consumed as a stream. While it arrives, the text inside
    ``<think>`` tags is shown live in a Streamlit placeholder.

    Args:
        query: The user's request.
        result: The value (or plot, or error string) produced for it.

    Returns:
        A ``(thinking, explanation)`` tuple of raw, unescaped strings: the
        text found inside ``<think>`` blocks and the remaining reply with
        those blocks removed. A think block that is never closed counts
        entirely as thinking and does not leak into the explanation.
    """
    import html  # local import: used only to escape model text for display

    current_config = get_current_config()
    prompt = ReasoningCurator(query, result)

    # Streaming LLM call
    response = client.chat.completions.create(
        model=current_config.MODEL_NAME,
        messages=[
            {"role": "system", "content": current_config.REASONING_TRUE},
            {"role": "user", "content": "You are an insightful data analyst. " + prompt}
        ],
        temperature=current_config.REASONING_TEMPERATURE,
        max_tokens=current_config.REASONING_MAX_TOKENS,
        stream=True
    )

    # Stream and display thinking
    thinking_placeholder = st.empty()
    full_response = ""
    thinking_content = ""

    for chunk in response:
        # Skip chunks that carry no choices or no text delta.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if not token:
            continue
        full_response += token

        # Re-derive the thinking text from everything received so far, so
        # tags split across chunks are still recognised. The reply length is
        # capped by max_tokens, which keeps this re-scan cheap.
        latest_thinking, _ = _split_thinking(full_response)
        if latest_thinking != thinking_content:
            thinking_content = latest_thinking
            # Model output is untrusted text: escape it before embedding it
            # in markup rendered with unsafe_allow_html.
            thinking_placeholder.markdown(
                f'<details class="thinking" open><summary>🤔 Model Thinking</summary><pre>{html.escape(thinking_content)}</pre></details>',
                unsafe_allow_html=True
            )

    # After streaming, extract final reasoning (outside <think>...</think>)
    thinking_content, cleaned = _split_thinking(full_response)
    return thinking_content, cleaned
366
+
367
+ # === DataFrameSummary TOOL (pandas only) =========================================
368
def DataFrameSummaryTool(df: pd.DataFrame) -> str:
    """Build the prompt asking the LLM to describe *df* and suggest questions.

    The prompt lists the row and column counts, the column labels, the dtype
    of every column and the number of missing values per column.

    Args:
        df: The dataset to summarise. Column labels of any type are accepted;
            they are converted with ``str`` before being joined.

    Returns:
        The prompt text.
    """
    column_list = ', '.join(map(str, df.columns))
    prompt = f"""
Given a dataset with {len(df)} rows and {len(df.columns)} columns:
Columns: {column_list}
Data types: {df.dtypes.to_dict()}
Missing values: {df.isnull().sum().to_dict()}

Provide:
1. A brief description of what this dataset contains
2. 3-4 possible data analysis questions that could be explored
Keep it concise and focused."""
    return prompt
381
+
382
+ # === DataInsightAgent (upload-time only) ===============================
383
+
384
def DataInsightAgent(df: pd.DataFrame) -> str:
    """Ask the LLM for a brief dataset summary and candidate questions.

    Args:
        df: The uploaded dataset, summarised via ``DataFrameSummaryTool``.

    Returns:
        The model's reply, or ``""`` when the reply carries no content.

    Raises:
        RuntimeError: When the request or the response handling fails. The
            original exception is chained as ``__cause__``.
    """
    current_config = get_current_config()
    prompt = DataFrameSummaryTool(df)
    try:
        response = client.chat.completions.create(
            model=current_config.MODEL_NAME,
            messages=[
                {"role": "system", "content": current_config.REASONING_FALSE},
                {"role": "user", "content": "You are a data analyst providing brief, focused insights. " + prompt}
            ],
            temperature=current_config.INSIGHTS_TEMPERATURE,
            max_tokens=current_config.INSIGHTS_MAX_TOKENS
        )
        return response.choices[0].message.content or ""
    except Exception as exc:
        # RuntimeError is still caught by callers handling ``Exception``;
        # chaining keeps the original traceback for debugging.
        raise RuntimeError(f"Error generating dataset insights: {exc}") from exc
401
+
402
+ # === Helpers ===========================================================
403
+
404
# Opening fence tagged "python" (any letter case, optional version suffix
# such as "python3"), then the shortest body up to the next closing fence.
_PYTHON_FENCE_RE = re.compile(r"```python[\d.]*(.*?)```", re.IGNORECASE | re.DOTALL)


def extract_first_code_block(text: str) -> str:
    """Return the contents of the first python-tagged markdown code fence.

    Args:
        text: Markdown text, typically an LLM reply. ``None`` and ``""`` are
            accepted.

    Returns:
        The code inside the first python fence with surrounding whitespace
        stripped, or ``""`` when *text* is empty, contains no python fence,
        or the fence is never closed.
    """
    if not text:
        return ""
    match = _PYTHON_FENCE_RE.search(text)
    if match is None:
        return ""
    return match.group(1).strip()
414
+
415
+ # === Main Streamlit App ===============================================
416
+
417
def _reset_conversation() -> None:
    """Clear the chat history and stored figures, closing each figure."""
    for fig in st.session_state.get("plots", []):
        plt.close(fig)  # release the figure from pyplot's registry
    st.session_state.messages = []
    st.session_state.plots = []


def main():
    """Render the Streamlit app: model/CSV sidebar plus the chat area.

    Sidebar: model selection, CSV upload and LLM-generated dataset insights.
    Main area: chat history and the chat input. Each question runs
    ``CodeGenerationAgent`` -> ``ExecutionAgent`` -> ``ReasoningAgent`` and
    the answer is stored in ``st.session_state.messages``.

    NOTE(review): the nesting of the ``with st.sidebar`` / ``with
    st.container`` sections was reconstructed from a flattened diff - confirm
    the layout against the running app.
    """
    import html  # local import: escapes untrusted text before HTML rendering

    st.set_page_config(layout="wide")
    if "plots" not in st.session_state:
        st.session_state.plots = []
    # Also repairs a stored key that is not a known model, which would make
    # the selectbox index lookup below raise ValueError.
    if st.session_state.get("current_model") not in MODEL_CONFIGS:
        st.session_state.current_model = (
            DEFAULT_MODEL if DEFAULT_MODEL in MODEL_CONFIGS else next(iter(MODEL_CONFIGS))
        )

    # Page logo at top right corner, large and clickable
    st.markdown(
        """
        <div style='position: absolute; top: 20px; right: 30px; z-index: 999;'>
        <a href='https://www.linkedin.com/in/thiresh-sidda/' target='_blank'>
        <img src='https://ih1.redbubble.net/image.1849728168.3104/raf,360x360,075,t,fafafa:ca443f4786.jpg' alt='Logo' style='height:120px; border-radius:20px; box-shadow:0 2px 12px rgba(0,0,0,0.15);'>
        </a>
        </div>
        """,
        unsafe_allow_html=True
    )
    # Main title centered with large font and GIF
    st.markdown(
        """
        <div style='display: flex; align-items: center; justify-content: center; margin-bottom: 30px;'>
        <span style='color:#1976D2; font-weight:bold; font-size:3.5em; margin-right:30px;'>Data Analysis Agent</span>
        <img src='https://cdn.dribbble.com/userupload/23161671/file/original-4c7894556285d8f223ab21fd10554fe4.gif' alt='GIF' style='height:120px;'>
        </div>
        """,
        unsafe_allow_html=True
    )

    medium_blue = "#1976D2"  # Medium blue color

    with st.sidebar:
        st.markdown(f"<span style='color:{medium_blue}; font-weight:bold; font-size:1.5em;'>Insights Generator</span>", unsafe_allow_html=True)
        available_models = list(MODEL_CONFIGS.keys())
        model_display_names = {key: MODEL_CONFIGS[key].MODEL_PRINT_NAME for key in available_models}
        selected_model = st.selectbox(
            "Select Model",
            options=available_models,
            format_func=lambda x: model_display_names[x],
            index=available_models.index(st.session_state.current_model)
        )
        file = st.file_uploader("Choose CSV", type=["csv"], key="csv_uploader")

        # Update configuration if model changed
        if selected_model != st.session_state.current_model:
            st.session_state.current_model = selected_model
            new_config = MODEL_CONFIGS[selected_model]
            _reset_conversation()
            if "df" in st.session_state and file is not None:
                with st.spinner("Generating dataset insights with new model …"):
                    try:
                        st.session_state.insights = DataInsightAgent(st.session_state.df)
                        st.success(f"Insights updated with {new_config.MODEL_PRINT_NAME}")
                    except Exception as e:
                        st.error(f"Error updating insights: {str(e)}")
                        if "insights" in st.session_state:
                            del st.session_state.insights
            st.rerun()

        # The uploader was cleared: forget the dataset that came from it.
        if not file and "df" in st.session_state and "current_file" in st.session_state:
            del st.session_state.df
            del st.session_state.current_file
            if "insights" in st.session_state:
                del st.session_state.insights
            st.rerun()

        if file:
            is_new_file = ("df" not in st.session_state) or (st.session_state.get("current_file") != file.name)
            if is_new_file:
                # Whatever happens next, state tied to the previous file is stale.
                for stale_key in ("df", "current_file", "insights"):
                    st.session_state.pop(stale_key, None)
                try:
                    loaded_df = pd.read_csv(file)
                except ValueError as e:
                    # pandas' ParserError and EmptyDataError, as well as
                    # UnicodeDecodeError, all derive from ValueError.
                    st.error(f"Could not read CSV: {str(e)}")
                else:
                    st.session_state.df = loaded_df
                    st.session_state.current_file = file.name
                    _reset_conversation()
                    with st.spinner("Generating dataset insights …"):
                        try:
                            st.session_state.insights = DataInsightAgent(st.session_state.df)
                        except Exception as e:
                            st.error(f"Error generating insights: {str(e)}")
            elif "insights" not in st.session_state:
                with st.spinner("Generating dataset insights …"):
                    try:
                        st.session_state.insights = DataInsightAgent(st.session_state.df)
                    except Exception as e:
                        st.error(f"Error generating insights: {str(e)}")

        if "df" in st.session_state:
            st.markdown(f"<span style='color:{medium_blue}; font-weight:bold; font-size:1.2em;'>Your Dataset Insights</span>", unsafe_allow_html=True)
            if "insights" in st.session_state and st.session_state.insights:
                st.dataframe(st.session_state.df.head())
                st.markdown(f"<span style='color:{medium_blue};'>{st.session_state.insights}</span>", unsafe_allow_html=True)
            else:
                st.warning("No insights available.")
        else:
            st.info("Upload a CSV to begin chatting with your data.")

    with st.container():
        st.markdown(
            f"""
            <div style='display: flex; align-items: center; justify-content: flex-start; margin-bottom: 10px;'>
            <span style='color:{medium_blue}; font-weight:bold; font-size:2em; margin-right:20px;'>Chat with your data</span>
            <img src='https://i.pinimg.com/originals/5f/d5/58/5fd558f8b7a4f9e2138709cbe63c7052.gif' alt='Chat GIF' style='height:48px;'>
            </div>
            """,
            unsafe_allow_html=True
        )
        if "df" in st.session_state:
            current_config_right = get_current_config()
            st.markdown(f"*<span style='color: grey; font-style: italic;'>Using {current_config_right.MODEL_PRINT_NAME}</span>*", unsafe_allow_html=True)
        if "messages" not in st.session_state:
            st.session_state.messages = []

        _, clear_col = st.columns([9, 1])
        with clear_col:
            if st.button("Clear chat"):
                _reset_conversation()
                st.rerun()

        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                # User text is stored raw (it is reused as LLM context) and
                # escaped here; assistant messages are pre-built HTML whose
                # untrusted parts were escaped when the message was created.
                shown = html.escape(msg["content"]) if msg["role"] == "user" else msg["content"]
                st.markdown(f"<span style='color:{medium_blue}; font-size:1.1em;'>{shown}</span>", unsafe_allow_html=True)
                if msg.get("plot_index") is not None:
                    idx = msg["plot_index"]
                    if 0 <= idx < len(st.session_state.plots):
                        st.pyplot(st.session_state.plots[idx], use_container_width=False)

        if "df" in st.session_state:
            if user_q := st.chat_input("Ask about your data…"):
                st.session_state.messages.append({"role": "user", "content": user_q})
                with st.spinner("Working …"):
                    recent_user_turns = [m["content"] for m in st.session_state.messages if m["role"] == "user"][-3:]
                    context_text = "\n".join(recent_user_turns[:-1]) if len(recent_user_turns) > 1 else None
                    try:
                        code, should_plot_flag, _ = CodeGenerationAgent(user_q, st.session_state.df, context_text)
                        result_obj = ExecutionAgent(code, st.session_state.df, should_plot_flag)
                        raw_thinking, reasoning_txt = ReasoningAgent(user_q, result_obj)
                    except Exception as e:
                        # Top-level UI boundary: report LLM/API failures
                        # instead of crashing the app with a traceback.
                        st.error(f"Error answering query: {str(e)}")
                        return
                    reasoning_txt = reasoning_txt.replace("`", "")

                    plot_idx = None
                    if isinstance(result_obj, (plt.Figure, plt.Axes)):
                        fig = result_obj.figure if isinstance(result_obj, plt.Axes) else result_obj
                        st.session_state.plots.append(fig)
                        plot_idx = len(st.session_state.plots) - 1

                    thinking_html = ""
                    if raw_thinking:
                        thinking_html = (
                            '<details class="thinking">'
                            '<summary>🧠 Reasoning</summary>'
                            f'<pre>{html.escape(raw_thinking)}</pre>'
                            '</details>'
                        )

                    # quote=False keeps quotation marks readable while still
                    # neutralising <, > and & in the model's text.
                    explanation_html = html.escape(reasoning_txt, quote=False)

                    code_html = (
                        '<details class="code">'
                        '<summary>View code</summary>'
                        '<pre><code class="language-python">'
                        f'{html.escape(code)}'
                        '</code></pre>'
                        '</details>'
                    )
                    assistant_msg = f"{thinking_html}{explanation_html}\n\n{code_html}"

                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": assistant_msg,
                        "plot_index": plot_idx
                    })
                st.rerun()


if __name__ == "__main__":
    main()
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
1
+ streamlit>=1.32.0
2
+ pandas>=2.2.0
3
+ matplotlib>=3.8.0
4
+ seaborn>=0.13.0
5
+ openai>=1.12.0
6
+ watchdog>=3.0.0