import os
import streamlit as st
import pandas as pd
import pickle
import base64
from io import BytesIO, StringIO
import sys
import operator
from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple
import tempfile
import shutil
import plotly.io as pio
import io
import re
import json
import openai
from datetime import datetime

from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image, PageBreak
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from PIL import Image as PILImage

from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_experimental.utilities import PythonREPL
from langchain_core.tools import tool
from langgraph.prebuilt import ToolInvocation, ToolExecutor, InjectedState
from langgraph.graph import StateGraph, END
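
# QueryMind: a Streamlit app that pairs a LangGraph tool-calling agent with
# in-memory pandas datasets. The agent writes and runs Python against the
# uploaded data; resulting plotly figures are pickled to a temp directory,
# then rendered in the chat, the dashboard, and the PDF report.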
if 'ai_provider' not in st.session_state:
    st.session_state.ai_provider = "openai"

if 'api_key' not in st.session_state:
    st.session_state.api_key = ""

if 'selected_model' not in st.session_state:
    st.session_state.selected_model = "gpt-4"

OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4o-mini", "gpt-3.5-turbo"]
GROQ_MODELS = ["llama-3.3-70b-versatile", "gemma2-9b-it", "llama3-8b-8192"]
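
# These lists feed the Settings tab and are passed verbatim as the `model`
# parameter, so every entry must be a model ID the provider's API accepts.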
if 'temp_dir' not in st.session_state:
    st.session_state.temp_dir = tempfile.mkdtemp()
    st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle")
    os.makedirs(st.session_state.images_dir, exist_ok=True)
    print(f"Created temporary directory: {st.session_state.temp_dir}")
    print(f"Created images directory: {st.session_state.images_dir}")
SYSTEM_PROMPT = """## Role
You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.

## Capabilities
1. **Execute python code** using the `complete_python_task` tool.

## Goals
1. Understand the user's objectives clearly.
2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyze their data to solve their problems.
3. Investigate if the goal is achievable by running Python code via the `python_code` field.
4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.

## Code Guidelines
- **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.
- **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.
- **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `df.head()`, `df.describe()` etc. otherwise.
- **ONLY USE THE FOLLOWING LIBRARIES**:
    - `pandas`
    - `sklearn` (including all major ML models)
    - `plotly`
    - `numpy`

All these libraries are already imported for you.

## Machine Learning Guidelines
- For regression tasks:
    - Linear Regression: `LinearRegression`
    - Ridge Regression: `Ridge`
    - Lasso Regression: `Lasso`
    - Random Forest Regression: `RandomForestRegressor`
- For classification tasks:
    - Logistic Regression: `LogisticRegression`
    - Decision Trees: `DecisionTreeClassifier`
    - Random Forests: `RandomForestClassifier`
    - Support Vector Machines: `SVC`
    - K-Nearest Neighbors: `KNeighborsClassifier`
    - Naive Bayes: `GaussianNB`
- For clustering:
    - K-Means: `KMeans`
    - DBSCAN: `DBSCAN`
- For dimensionality reduction:
    - PCA: `PCA`
- Always preprocess data appropriately:
    - Scale numerical features with `StandardScaler` or `MinMaxScaler`
    - Encode categorical variables with `OneHotEncoder` when needed
    - Handle missing values with `SimpleImputer`
- Always split data into training and testing sets using `train_test_split`
- Evaluate models using appropriate metrics:
    - For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score`
    - For classification: `accuracy_score`, `confusion_matrix`, `classification_report`
    - For clustering: `silhouette_score`
- Consider using `cross_val_score` for more robust evaluation
- Visualize ML results with plotly when possible

## Plotting Guidelines
- Always use the `plotly` library for plotting.
- Store all plotly figures inside a `plotly_figures` list; they will be saved automatically.
- Do not try to show the plots inline with `fig.show()`.
"""
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    input_data: Annotated[List[Dict], operator.add]
    intermediate_outputs: Annotated[List[dict], operator.add]
    current_variables: dict
    output_image_paths: Annotated[List[str], operator.add]


if 'in_memory_datasets' not in st.session_state:
    st.session_state.in_memory_datasets = {}

if 'persistent_vars' not in st.session_state:
    st.session_state.persistent_vars = {}

if 'dataset_metadata_list' not in st.session_state:
    st.session_state.dataset_metadata_list = []

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

if 'dashboard_plots' not in st.session_state:
    st.session_state.dashboard_plots = [None, None, None, None]

if 'columns' not in st.session_state:
    st.session_state.columns = ["No columns available"]

if 'custom_plots_to_save' not in st.session_state:
    st.session_state.custom_plots_to_save = {}

repl = PythonREPL()

plotly_saving_code = """import pickle
import uuid
import os

for figure in plotly_figures:
    pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
    with open(pickle_filename, 'wb') as f:
        pickle.dump(figure, f)
"""
@tool
def complete_python_task(
    graph_state: Annotated[dict, InjectedState],
    thought: str,
    python_code: str
) -> Tuple[str, dict]:
    """Execute Python code for data analysis and visualization."""
    current_variables = graph_state.get("current_variables", {})

    for input_dataset in graph_state.get("input_data", []):
        var_name = input_dataset.get("variable_name")
        if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
            print(f"Loading {var_name} from in-memory storage")
            current_variables[var_name] = st.session_state.in_memory_datasets[var_name]

    current_image_pickle_files = os.listdir(st.session_state.images_dir)

    try:
        old_stdout = sys.stdout
        sys.stdout = StringIO()

        exec_globals = globals().copy()
        exec_globals.update(st.session_state.persistent_vars)
        exec_globals.update(current_variables)

        import sklearn
        import numpy as np

        from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier
        from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
        from sklearn.svm import SVC, SVR
        from sklearn.naive_bayes import GaussianNB
        from sklearn.decomposition import PCA
        from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
        from sklearn.cluster import KMeans, DBSCAN
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
        from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
        from sklearn.metrics import (
            accuracy_score, confusion_matrix, classification_report,
            mean_squared_error, r2_score, mean_absolute_error, silhouette_score
        )
        from sklearn.pipeline import Pipeline
        from sklearn.impute import SimpleImputer

        exec_globals.update({
            "plotly_figures": [],
            "images_dir": st.session_state.images_dir,
            "np": np,
            "LinearRegression": LinearRegression,
            "LogisticRegression": LogisticRegression,
            "Ridge": Ridge,
            "Lasso": Lasso,
            "DecisionTreeClassifier": DecisionTreeClassifier,
            "DecisionTreeRegressor": DecisionTreeRegressor,
            "RandomForestClassifier": RandomForestClassifier,
            "RandomForestRegressor": RandomForestRegressor,
            "GradientBoostingClassifier": GradientBoostingClassifier,
            "SVC": SVC,
            "SVR": SVR,
            "GaussianNB": GaussianNB,
            "PCA": PCA,
            "KNeighborsClassifier": KNeighborsClassifier,
            "KNeighborsRegressor": KNeighborsRegressor,
            "KMeans": KMeans,
            "DBSCAN": DBSCAN,
            "StandardScaler": StandardScaler,
            "MinMaxScaler": MinMaxScaler,
            "OneHotEncoder": OneHotEncoder,
            "SimpleImputer": SimpleImputer,
            "train_test_split": train_test_split,
            "cross_val_score": cross_val_score,
            "GridSearchCV": GridSearchCV,
            "accuracy_score": accuracy_score,
            "confusion_matrix": confusion_matrix,
            "classification_report": classification_report,
            "mean_squared_error": mean_squared_error,
            "r2_score": r2_score,
            "mean_absolute_error": mean_absolute_error,
            "silhouette_score": silhouette_score,
            "Pipeline": Pipeline,
        })
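
        # The user's code executes against a copy of this module's globals plus
        # any variables persisted from earlier runs; anything new it defines is
        # harvested back into st.session_state.persistent_vars afterwards.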
        exec(python_code, exec_globals)

        st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})

        output = sys.stdout.getvalue()
        sys.stdout = old_stdout

        updated_state = {
            "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
            "current_variables": st.session_state.persistent_vars
        }

        if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
            exec(plotly_saving_code, exec_globals)

        new_image_folder_contents = os.listdir(st.session_state.images_dir)
        new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]
        if new_image_files:
            updated_state["output_image_paths"] = new_image_files

        st.session_state.persistent_vars["plotly_figures"] = []
        return output, updated_state

    except Exception as e:
        sys.stdout = old_stdout
        print(f"Error in complete_python_task: {str(e)}")
        return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}


def initialize_llm():
    api_key = st.session_state.api_key
    model = st.session_state.selected_model

    if not api_key:
        return None

    try:
        if st.session_state.ai_provider == "openai":
            os.environ["OPENAI_API_KEY"] = api_key
            return ChatOpenAI(model=model, temperature=0)
        elif st.session_state.ai_provider == "groq":
            os.environ["GROQ_API_KEY"] = api_key
            from langchain_groq import ChatGroq
            return ChatGroq(model=model, temperature=0)
    except Exception as e:
        print(f"Error initializing LLM: {str(e)}")
        return None
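
# ToolInvocation and ToolExecutor are imported from langgraph.prebuilt above;
# newer langgraph releases replace them with ToolNode, so this code assumes an
# older langgraph version that still ships these helpers.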
tools = [complete_python_task]
tool_executor = ToolExecutor(tools)

chat_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])


def create_data_summary(state: AgentState) -> str:
    summary = ""
    variables = []

    for d in state.get("input_data", []):
        var_name = d.get("variable_name")
        if var_name:
            variables.append(var_name)
            summary += f"\n\nVariable: {var_name}\n"
            summary += f"Description: {d.get('data_description', 'No description')}\n"

            if var_name in st.session_state.in_memory_datasets:
                df = st.session_state.in_memory_datasets[var_name]
                summary += "\nSample Data (first 5 rows):\n"
                summary += df.head(5).to_string()

    if "current_variables" in state:
        remaining_variables = [v for v in state["current_variables"] if v not in variables and not v.startswith("_")]
        for v in remaining_variables:
            var_value = state["current_variables"].get(v)
            if isinstance(var_value, pd.DataFrame):
                summary += f"\n\nVariable: {v} (DataFrame with shape {var_value.shape})"
            else:
                summary += f"\n\nVariable: {v}"

    return summary


def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
    """Determine if we should route to tools or end the chain."""
    if messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return "__end__"


def call_model(state: AgentState):
    """Call the LLM to get a response."""
    current_data_template = """The following data is available:\n{data_summary}"""
    current_data_message = HumanMessage(
        content=current_data_template.format(data_summary=create_data_summary(state))
    )
    messages = [current_data_message] + state["messages"]

    llm = initialize_llm()
    if llm is None:
        return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]}

    model = llm.bind_tools(tools)
    model = chat_template | model

    llm_outputs = model.invoke({"messages": messages})
    return {"messages": [llm_outputs], "intermediate_outputs": [current_data_message.content]}


def call_tools(state: AgentState):
    """Execute tools called by the LLM."""
    last_message = state["messages"][-1]
    tool_invocations = []

    if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
        tool_invocations = [
            ToolInvocation(
                tool=tool_call["name"],
                tool_input={**tool_call["args"], "graph_state": state}
            ) for tool_call in last_message.tool_calls
        ]

    responses = tool_executor.batch(tool_invocations, return_exceptions=True)

    tool_messages = []
    state_updates = {}

    for tc, response in zip(last_message.tool_calls, responses):
        if isinstance(response, Exception):
            print(f"Exception in tool execution: {str(response)}")
            tool_messages.append(ToolMessage(
                content=f"Error: {str(response)}",
                name=tc["name"],
                tool_call_id=tc["id"]
            ))
            continue

        message, updates = response
        tool_messages.append(ToolMessage(
            content=str(message),
            name=tc["name"],
            tool_call_id=tc["id"]
        ))

        for key, value in updates.items():
            if key in state_updates:
                if isinstance(value, list) and isinstance(state_updates[key], list):
                    state_updates[key].extend(value)
                elif isinstance(value, dict) and isinstance(state_updates[key], dict):
                    state_updates[key].update(value)
                else:
                    state_updates[key] = value
            else:
                state_updates[key] = value

    state_updates["messages"] = tool_messages
    return state_updates
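
# Graph topology: agent -> (tools -> agent)* -> END. The agent node calls the
# LLM; whenever its reply carries tool calls, route_to_tools loops through the
# tools node and the resulting ToolMessages are fed back to the agent.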
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", call_tools)
workflow.add_conditional_edges(
    "agent",
    route_to_tools,
    {
        "tools": "tools",
        "__end__": END
    }
)
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")

chain = workflow.compile()
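
# Uploaded files are parsed into DataFrames keyed by a normalized variable
# name (e.g. "Sales-Data 2024.csv" becomes `sales_data_2024`) that agent
# code can reference directly.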
def process_file_upload(files):
    """Process uploaded files and return dataframe previews and column names."""
    st.session_state.in_memory_datasets = {}
    st.session_state.dataset_metadata_list = []
    st.session_state.persistent_vars.clear()

    if not files:
        return "No files uploaded.", [], ["No columns available"]

    results = []
    all_columns = []

    for file in files:
        try:
            if file.name.endswith('.csv'):
                df = pd.read_csv(file)
            elif file.name.endswith(('.xls', '.xlsx')):
                df = pd.read_excel(file)
            else:
                results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.")
                continue

            var_name = file.name.split('.')[0].replace('-', '_').replace(' ', '_').lower()
            st.session_state.in_memory_datasets[var_name] = df

            all_columns.extend(df.columns.tolist())

            dataset_metadata = {
                "variable_name": var_name,
                "data_path": "in_memory",
                "data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}",
                "original_filename": file.name
            }
            st.session_state.dataset_metadata_list.append(dataset_metadata)

            preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n"
            preview += df.head(10).to_markdown()
            results.append(preview)
            print(f"Successfully processed {file.name}")

        except Exception as e:
            print(f"Error processing {file.name}: {str(e)}")
            results.append(f"Error processing {file.name}: {str(e)}")

    unique_columns = []
    seen = set()
    for col in all_columns:
        if col not in seen:
            seen.add(col)
            unique_columns.append(col)

    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Found {len(unique_columns)} unique columns across datasets")
    return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns


def get_columns():
    """Directly gets columns from in-memory datasets."""
    all_columns = []
    for var_name, df in st.session_state.in_memory_datasets.items():
        if isinstance(df, pd.DataFrame):
            all_columns.extend(df.columns.tolist())

    unique_columns = []
    seen = set()
    for col in all_columns:
        if col not in seen:
            seen.add(col)
            unique_columns.append(col)

    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Populating dropdowns with {len(unique_columns)} columns")
    return unique_columns
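
# Deterministic first-pass cleaning: snake_case the column names, drop exact
# duplicate rows and fully empty rows/columns, and strip whitespace from
# object-typed cells. The LLM-suggested cleaning below builds on this.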
def standard_clean(df):
    df.columns = [re.sub(r'\W+', '_', col).strip().lower() for col in df.columns]
    df.drop_duplicates(inplace=True)
    df.dropna(axis=1, how='all', inplace=True)
    df.dropna(axis=0, how='all', inplace=True)
    for col in df.select_dtypes(include='object').columns:
        df[col] = df[col].astype(str).str.strip()
    return df


def query_openai(prompt):
    try:
        api_key = st.session_state.api_key
        model = st.session_state.selected_model

        if st.session_state.ai_provider == "openai":
            client = openai.OpenAI(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
        elif st.session_state.ai_provider == "groq":
            from groq import Groq
            client = Groq(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
    except Exception as e:
        print(f"API Error: {e}")
        return "{}"
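
# llm_suggest_cleaning asks the model for a plain dict literal and parses it
# with eval() below; that trusts the response to be a benign expression, and
# ast.literal_eval would be the safer parser if the replies stay literals.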
def llm_suggest_cleaning(df):
    sample = df.head(10).to_csv(index=False)
    prompt = f"""
You are a professional data wrangler. Below is a sample of a messy dataset.

Return a Python dictionary with the following keys:

1. rename_columns – fix unclear or inconsistent column names
2. convert_types – correct datatypes: int, float, str, or date
3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0
4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes)

Do not drop any rows or columns. Your output must be a valid Python dict.

Example:
{{
    "rename_columns": {{"dob": "date_of_birth"}},
    "convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}},
    "fill_missing": {{"gender": "mode", "salary": -1}},
    "value_map": {{
        "gender": {{"M": "Male", "F": "Female"}},
        "subscribed": {{"Y": "Yes", "N": "No"}}
    }}
}}

Beyond these steps, study the data and include any additional cleaning actions that suit this particular dataset.

Sample data:
{sample}
"""
    raw_response = query_openai(prompt)
    try:
        suggestions = eval(raw_response)
        return suggestions
    except Exception:
        print("Could not parse suggestions.")
        return {
            "rename_columns": {},
            "convert_types": {},
            "fill_missing": {},
            "value_map": {}
        }


def apply_suggestions(df, suggestions):
    df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)

    for col, dtype in suggestions.get("convert_types", {}).items():
        if col not in df.columns:
            continue
        try:
            if dtype == "int":
                df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
            elif dtype == "float":
                df[col] = pd.to_numeric(df[col], errors='coerce')
            elif dtype == "str":
                df[col] = df[col].astype(str)
            elif dtype == "date":
                df[col] = pd.to_datetime(df[col], errors='coerce')
        except Exception:
            print(f"Failed to convert {col} to {dtype}")

    for col, method in suggestions.get("fill_missing", {}).items():
        if col not in df.columns:
            continue
        try:
            if method == "mean":
                df[col].fillna(df[col].mean(), inplace=True)
            elif method == "median":
                df[col].fillna(df[col].median(), inplace=True)
            elif method == "mode":
                df[col].fillna(df[col].mode().iloc[0], inplace=True)
            else:
                df[col].fillna(method, inplace=True)
        except Exception:
            print(f"Could not fill missing values for {col}")

    for col, mapping in suggestions.get("value_map", {}).items():
        if col in df.columns:
            try:
                df[col] = df[col].replace(mapping)
            except Exception:
                print(f"Could not map values in {col}")

    return df
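
# In a 2x2 make_subplots figure the axes are named xaxis, xaxis2, xaxis3,
# xaxis4 (likewise for y) in row-major order; the loop below copies each
# source plot's axis styling onto the matching subplot axis, stripping
# 'domain' and 'anchor' so the copied properties cannot break the grid.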
def capture_dashboard_screenshot():
    """Capture the entire dashboard as a single image."""
    try:
        from plotly.subplots import make_subplots

        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=["Visualization 1", "Visualization 2",
                                            "Visualization 3", "Visualization 4"])

        for i, plot in enumerate(st.session_state.dashboard_plots):
            if plot is not None:
                row = (i // 2) + 1
                col = (i % 2) + 1

                for trace in plot.data:
                    fig.add_trace(trace, row=row, col=col)

                for axis_type in ['xaxis', 'yaxis']:
                    subplot_name = f"{axis_type}{i + 1 if i > 0 else ''}"

                    if hasattr(plot.layout, axis_type):
                        axis_props = getattr(plot.layout, axis_type).to_plotly_json()
                        axis_props.pop('domain', None)
                        axis_props.pop('anchor', None)
                        fig.update_layout({subplot_name: axis_props})

        fig.update_layout(
            height=800,
            width=1000,
            title_text="Dashboard Overview",
            showlegend=False,
        )

        dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
        fig.write_image(dashboard_path, scale=2)
        return dashboard_path

    except Exception as e:
        import traceback
        print(f"Error capturing dashboard: {str(e)}")
        print(traceback.format_exc())
        return None


def generate_enhanced_pdf_report():
    """Generate an enhanced PDF report with proper handling of base64 image data."""
    try:
        buffer = io.BytesIO()

        doc = SimpleDocTemplate(buffer, pagesize=letter,
                                leftMargin=36, rightMargin=36,
                                topMargin=36, bottomMargin=36)

        styles = getSampleStyleSheet()

        styles.add(ParagraphStyle(
            name='ReportTitle',
            parent=styles['Heading1'],
            fontSize=24,
            alignment=1,
            spaceAfter=20,
            textColor='#2C3E50'
        ))

        styles.add(ParagraphStyle(
            name='SectionHeader',
            parent=styles['Heading2'],
            fontSize=16,
            spaceBefore=15,
            spaceAfter=10,
            textColor='#2C3E50',
            borderWidth=1,
            borderColor='#95A5A6',
            borderPadding=5,
            borderRadius=5
        ))

        styles.add(ParagraphStyle(
            name='SubHeader',
            parent=styles['Heading3'],
            fontSize=14,
            spaceBefore=10,
            spaceAfter=8,
            textColor='#34495E'
        ))

        styles.add(ParagraphStyle(
            name='UserMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=8,
            spaceAfter=4
        ))

        styles.add(ParagraphStyle(
            name='AssistantMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=4,
            spaceAfter=12,
            textColor='#2980B9'
        ))

        styles.add(ParagraphStyle(
            name='Timestamp',
            parent=styles['Italic'],
            fontSize=10,
            textColor='#7F8C8D',
            alignment=2
        ))

        elements = []

        elements.append(Paragraph('Data Analysis Report', styles['ReportTitle']))
        elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
                                  styles['Timestamp']))
        elements.append(Spacer(1, 0.5 * inch))

        elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader']))

        if st.session_state.chat_history:
            for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
                elements.append(Paragraph('<b>You:</b>', styles['SubHeader']))
                user_msg_formatted = user_msg.replace('\n', '<br/>')
                elements.append(Paragraph(user_msg_formatted, styles['UserMessage']))

                base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)'

                if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg):
                    if '### Visualizations' in assistant_msg:
                        parts = assistant_msg.split('### Visualizations', 1)
                        text_part = parts[0]
                        viz_part = ("### Visualizations" + parts[1]) if len(parts) > 1 else ""
                    else:
                        match = re.search(base64_pattern, assistant_msg)
                        text_part = assistant_msg[:match.start()]
                        viz_part = assistant_msg[match.start():]

                    elements.append(Paragraph('<b>Assistant:</b>', styles['SubHeader']))
                    text_part = text_part.replace('\n', '<br/>')
                    elements.append(Paragraph(text_part, styles['AssistantMessage']))

                    matches = re.findall(base64_pattern, viz_part)
                    for j, base64_data in enumerate(matches):
                        try:
                            image_data = base64.b64decode(base64_data)

                            temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png"
                            with open(temp_img_path, 'wb') as f:
                                f.write(image_data)

                            elements.append(Paragraph('<b>Visualization:</b>', styles['SubHeader']))
                            elements.append(Spacer(1, 0.1 * inch))
                            img = Image(temp_img_path, width=6 * inch, height=4 * inch)
                            elements.append(img)
                            elements.append(Spacer(1, 0.2 * inch))
                        except Exception as e:
                            print(f"Error processing base64 image: {str(e)}")
                            elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]",
                                                      styles['Normal']))
                else:
                    elements.append(Paragraph('<b>Assistant:</b>', styles['SubHeader']))
                    assistant_msg_formatted = assistant_msg.replace('\n', '<br/>')
                    if len(assistant_msg_formatted) > 1500:
                        assistant_msg_formatted = assistant_msg_formatted[:1500] + '...'
                    elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage']))

                elements.append(Spacer(1, 0.2 * inch))
        else:
            elements.append(Paragraph('No conversation history available.', styles['Normal']))

        elements.append(PageBreak())

        elements.append(Paragraph('Dashboard Overview', styles['SectionHeader']))
        elements.append(Spacer(1, 0.2 * inch))

        dashboard_img_path = capture_dashboard_screenshot()

        if dashboard_img_path:
            available_width = doc.width

            pil_img = PILImage.open(dashboard_img_path)
            img_width, img_height = pil_img.size

            scale_factor = available_width / img_width
            new_height = img_height * scale_factor

            img = Image(dashboard_img_path, width=available_width, height=new_height)
            elements.append(img)
        else:
            plot_count = 0
            for i, plot in enumerate(st.session_state.dashboard_plots):
                if plot is not None:
                    plot_count += 1

                    img_bytes = io.BytesIO()
                    plot.write_image(img_bytes, format='png', width=500, height=300)
                    img_bytes.seek(0)

                    temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"
                    with open(temp_img_path, 'wb') as f:
                        f.write(img_bytes.getvalue())

                    elements.append(Paragraph(f'Dashboard Visualization {i + 1}', styles['SubHeader']))
                    elements.append(Spacer(1, 0.1 * inch))

                    img = Image(temp_img_path, width=6.5 * inch, height=4 * inch)
                    elements.append(img)
                    elements.append(Spacer(1, 0.3 * inch))

            if plot_count == 0:
                elements.append(Paragraph('No visualizations have been added to the dashboard.',
                                          styles['Normal']))

        doc.build(elements)

        pdf_value = buffer.getvalue()
        buffer.close()

        return pdf_value

    except Exception as e:
        import traceback
        print(f"Error generating enhanced PDF report: {str(e)}")
        print(traceback.format_exc())
        return None
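
# Figures are embedded in responses as markdown data URIs of the form
# ![Visualization](data:image/png;base64,...), the exact shape that the
# base64_pattern regex in generate_enhanced_pdf_report() extracts later.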
def chat_with_workflow(message, history, dataset_info):
    """Send user query to the workflow and get response."""
    if not dataset_info:
        return "Please upload at least one dataset before asking questions."

    if not st.session_state.api_key:
        return "Please set up your API key and model in the Settings tab before chatting."

    print(f"Chat with workflow called with {len(dataset_info)} datasets")

    try:
        max_history = 3
        previous_messages = []

        if history:
            start_idx = max(0, len(history) - max_history)
            recent_history = history[start_idx:]

            for exchange in recent_history:
                if exchange[0]:
                    previous_messages.append(HumanMessage(content=exchange[0]))
                if exchange[1]:
                    previous_messages.append(AIMessage(content=exchange[1]))

        state = AgentState(
            messages=previous_messages + [HumanMessage(content=message)],
            input_data=dataset_info,
            intermediate_outputs=[],
            current_variables=st.session_state.persistent_vars,
            output_image_paths=[]
        )

        print("Executing workflow...")
        result = chain.invoke(state)
        print("Workflow execution completed")

        messages = result["messages"]

        response = ""
        if messages:
            latest_message = messages[-1]
            if hasattr(latest_message, "content"):
                content = latest_message.content

                if message in content:
                    content = content.split(message)[-1].strip()

                content_lines = content.split('\n')
                filtered_lines = [line for line in content_lines
                                  if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))]
                content = '\n'.join(filtered_lines)

                response = content.strip() + "\n\n"

        if "output_image_paths" in result and result["output_image_paths"]:
            response += "### Visualizations\n\n"
            for img_path in result["output_image_paths"]:
                try:
                    full_path = os.path.join(st.session_state.images_dir, img_path)
                    with open(full_path, 'rb') as f:
                        fig = pickle.load(f)

                    img_bytes = BytesIO()
                    fig.update_layout(width=800, height=500)
                    pio.write_image(fig, img_bytes, format='png')
                    img_bytes.seek(0)

                    b64_img = base64.b64encode(img_bytes.read()).decode()
                    response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"
                except Exception as e:
                    response += f"Error loading visualization: {str(e)}\n\n"

        return response

    except Exception as e:
        import traceback
        print(f"Error in chat_with_workflow: {str(e)}")
        print(traceback.format_exc())
        return f"Error executing workflow: {str(e)}"


def auto_generate_dashboard(dataset_info):
    """Generate an automatic dashboard with four plots."""
    if not dataset_info:
        return "Please upload a dataset first.", [None, None, None, None]

    prompt = """
    You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).

    Use plotly and store the plots in a list named plotly_figures.

    Include multivariate plots using color/size/facets when helpful.
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)
    figures = []

    if "output_image_paths" in result:
        for img_path in result["output_image_paths"][:4]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                with open(full_path, 'rb') as f:
                    fig = pickle.load(f)
                figures.append(fig)
            except Exception as e:
                print(f"Error loading figure: {e}")

    while len(figures) < 4:
        figures.append(None)

    st.session_state.dashboard_plots = figures
    return "Dashboard generated!", figures


def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
    """Generate custom plots based on user-selected columns."""
    if not dataset_info or not x_col or not y_col:
        return [None, None, None]

    prompt = f"""
    You are a data visualization expert.

    Create 3 insightful visualizations using Plotly based on:

    - X-axis: {x_col}
    - Y-axis: {y_col}
    - Facet (optional): {facet_col if facet_col != 'None' else 'None'}

    Try to find interesting relationships, trends, or clusters using appropriate chart types.

    Use the `plotly_figures` list and avoid using fig.show().
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)
    figures = []

    if "output_image_paths" in result:
        for img_path in result["output_image_paths"][:3]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                with open(full_path, 'rb') as f:
                    fig = pickle.load(f)
                figures.append(fig)
            except Exception as e:
                print(f"Error loading figure: {e}")

    while len(figures) < 3:
        figures.append(None)
    return figures


def remove_plot(index):
    """Remove a plot from the dashboard."""
    if 0 <= index < len(st.session_state.dashboard_plots):
        st.session_state.dashboard_plots[index] = None


def respond(message):
    """Handle chat message response."""
    if not st.session_state.dataset_metadata_list:
        bot_message = "Please upload at least one dataset before asking questions."
    else:
        bot_message = chat_with_workflow(message, st.session_state.chat_history, st.session_state.dataset_metadata_list)

    st.session_state.chat_history.append((message, bot_message))
    st.rerun()


def save_plot_to_dashboard(plot_index):
    """Callback for the Add Plot button."""
    for i in range(len(st.session_state.dashboard_plots)):
        if st.session_state.dashboard_plots[i] is None:
            st.session_state.dashboard_plots[i] = st.session_state.custom_plots_to_save[plot_index]
            return
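
# st.set_page_config must be the first Streamlit rendering command in the
# script; the session_state initialization above is permitted because it
# draws nothing to the page.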
st.set_page_config(page_title="QueryMind 🧠", layout="wide")
st.title("QueryMind 🧠 - Data Assistant")
st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")

tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report", "Settings"])

with tab1:
    st.header("Upload Datasets")
    uploaded_files = st.file_uploader("Upload CSV or Excel Files",
                                      accept_multiple_files=True,
                                      type=['csv', 'xlsx', 'xls'])

    if uploaded_files and st.button("Process Uploaded Files"):
        with st.spinner("Processing files..."):
            preview, metadata_list, columns = process_file_upload(uploaded_files)
            st.session_state.columns = columns

        st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")

        st.subheader("Dataset Previews")

        for dataset_name, df in st.session_state.in_memory_datasets.items():
            with st.expander(f"Preview: {dataset_name}"):
                st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")

                col_info = pd.DataFrame({
                    'Column Name': df.columns,
                    'Data Type': df.dtypes.astype(str),
                    'Non-Null Count': df.count().values,
                    'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
                })

                st.write("**Column Information:**")
                st.dataframe(col_info, use_container_width=True)

                st.write("**Data Preview (First 10 rows):**")
                st.dataframe(df.head(10), use_container_width=True)

        st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.")

with tab2:
    st.header("Data Cleaning")

    if 'cleaning_done' not in st.session_state:
        st.session_state.cleaning_done = False
    if 'cleaned_datasets' not in st.session_state:
        st.session_state.cleaned_datasets = {}
    if 'cleaning_summaries' not in st.session_state:
        st.session_state.cleaning_summaries = {}

    if st.session_state.get("in_memory_datasets"):
        if not st.session_state.cleaning_done:
            if st.button("Run Data Cleaning"):
                with st.spinner("Running LLM-assisted cleaning..."):
                    for name, df in st.session_state.in_memory_datasets.items():
                        raw_df = df.copy()
                        df_std = standard_clean(raw_df.copy())
                        suggestions = llm_suggest_cleaning(df_std.copy())
                        df_clean = apply_suggestions(df_std.copy(), suggestions)
                        st.session_state.cleaned_datasets[name] = df_clean
                        st.session_state.cleaning_summaries[name] = suggestions
                    st.session_state.cleaning_done = True
                    st.rerun()
            else:
                st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
        else:
            for name, df_clean in st.session_state.cleaned_datasets.items():
                raw_df = st.session_state.in_memory_datasets[name]

                st.subheader(f"Dataset: {name}")
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("Original Data (First 5 Rows)")
                    st.dataframe(raw_df.head())

                with col2:
                    st.markdown("Cleaned Data (First 5 Rows)")
                    st.dataframe(df_clean.head())

                st.markdown("Summary of Cleaning Actions")
                suggestions = st.session_state.cleaning_summaries[name]
                summary_text = ""

                if suggestions:
                    for key, value in suggestions.items():
                        summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
                    st.markdown(summary_text)

                st.markdown("Refine the Cleaning (Natural Language Instructions)")
                user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'",
                                           key=f"user_input_{name}")

                if f'corrections_{name}' not in st.session_state:
                    st.session_state[f'corrections_{name}'] = []

                if st.button("Apply Correction", key=f'apply_correction_{name}'):
                    if user_input.strip():
                        correction_prompt = f"""
You are a data cleaning expert. Below is a previously cleaned dataset with these actions:

{summary_text}

The user now wants the following additional instruction:
\"{user_input.strip()}\"

Write only the Python code that modifies the pandas DataFrame `df` accordingly.
Do not include explanations or markdown.
"""
                        correction_code = query_openai(correction_prompt)

                        try:
                            df = st.session_state.cleaned_datasets[name].copy()
                            local_vars = {"df": df}
                            exec(correction_code, {"pd": pd}, local_vars)
                            df_updated = local_vars["df"]

                            st.session_state.cleaned_datasets[name] = df_updated
                            st.session_state[f'corrections_{name}'].append((user_input, correction_code))
                            st.success("Correction applied.")
                            st.rerun()
                        except Exception as e:
                            st.error(f"Failed to apply correction: {str(e)}")

                if st.session_state[f'corrections_{name}']:
                    st.markdown("Applied Corrections")
                    for i, (msg, code) in enumerate(st.session_state[f'corrections_{name}']):
                        st.markdown(f"**Instruction:** {msg}")
                        st.code(code, language='python')

            col1, col2 = st.columns([1, 2])
            with col1:
                if st.button("Reset Cleaning and Re-run"):
                    st.session_state.cleaning_done = False
                    st.rerun()

            with col2:
                if st.button("Finalize and Proceed to Visualizations"):
                    st.session_state.cleaning_finalized = True
                    st.rerun()
    else:
        st.info("Please upload and process datasets first.")

with tab3:
    st.header("Chat with AI Assistant")

    if not st.session_state.api_key:
        st.warning("⚠️ Please set up your API key and model in the Settings tab before using the chat.")

    st.markdown("""
    ## Example Questions
    - "What analysis can you perform on this dataset?"
    - "Show me basic statistics for all columns"
    - "Create a correlation heatmap"
    - "Plot the distribution of a specific column"
    - "What is the relationship between two columns?"
    """)

    for exchange in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(exchange[0])
        with st.chat_message("assistant"):
            st.write(exchange[1])

    if prompt := st.chat_input("Your question"):
        with st.spinner("Thinking..."):
            respond(prompt)

with tab4:
    st.header("Auto Dashboard Generator")

    dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")

    col1, col2 = st.columns(2)

    with col1:
        if st.button("Generate Suggested Dashboard (Auto)"):
            if not st.session_state.api_key:
                st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
            else:
                with st.spinner("Generating dashboard..."):
                    message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
                    st.success(message)

    with col2:
        if st.button("Refresh Column Options"):
            st.session_state.columns = get_columns()
            st.rerun()

    st.subheader("Dashboard")

    col1, col2 = st.columns(2)

    with col1:
        if st.session_state.dashboard_plots[0]:
            st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
            if st.button("Remove Plot 1"):
                remove_plot(0)
                st.rerun()

    with col2:
        if st.session_state.dashboard_plots[1]:
            st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
            if st.button("Remove Plot 2"):
                remove_plot(1)
                st.rerun()

    col3, col4 = st.columns(2)

    with col3:
        if st.session_state.dashboard_plots[2]:
            st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
            if st.button("Remove Plot 3"):
                remove_plot(2)
                st.rerun()

    with col4:
        if st.session_state.dashboard_plots[3]:
            st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
            if st.button("Remove Plot 4"):
                remove_plot(3)
                st.rerun()

    st.subheader("Custom Plot Generator")

    col1, col2, col3 = st.columns(3)

    with col1:
        x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)
    with col2:
        y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)
    with col3:
        facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)

    if st.button("Generate Custom Visualizations"):
        if not st.session_state.api_key:
            st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
        else:
            with st.spinner("Generating custom visualizations..."):
                custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)

                for i, plot in enumerate(custom_plots):
                    if plot:
                        st.session_state.custom_plots_to_save[i] = plot

                for i, plot in enumerate(custom_plots):
                    if plot:
                        st.plotly_chart(plot, use_container_width=True)
                        st.button(
                            f"Add Plot {i + 1} to Dashboard",
                            key=f"add_plot_{i}",
                            on_click=save_plot_to_dashboard,
                            args=(i,)
                        )

with tab5:
    st.header("Generate Analysis Report")

    st.markdown("""
    Generate a PDF report containing:
    - Dashboard visualizations
    - Chat conversation history
    """)

    report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")

    if st.button("Generate PDF Report"):
        if not st.session_state.api_key:
            st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
        else:
            with st.spinner("Generating report..."):
                pdf_data = generate_enhanced_pdf_report()
                if pdf_data:
                    b64_pdf = base64.b64encode(pdf_data).decode('utf-8')
                    pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
                    st.markdown("### Your report is ready!")
                    st.markdown(pdf_download_link, unsafe_allow_html=True)

                    with st.expander("Preview Report"):
                        st.warning("PDF preview is not available in Streamlit, please download the report to view it.")
                else:
                    st.error("Failed to generate the report. Please try again.")

with tab6:
    st.header("AI Provider Settings")

    provider = st.radio("Select AI Provider",
                        options=["OpenAI", "Groq"],
                        index=0 if st.session_state.ai_provider == "openai" else 1,
                        horizontal=True)

    st.session_state.ai_provider = provider.lower()

    api_key = st.text_input("Enter API Key",
                            value=st.session_state.api_key,
                            type="password",
                            help="Your API key for the selected provider")

    if st.session_state.ai_provider == "openai":
        model_options = OPENAI_MODELS
        model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable."
    else:
        model_options = GROQ_MODELS
        model_help = "Llama 3.3 70B is most capable. Gemma 2 9B offers a good balance. Llama 3 8B is fastest."

    selected_model = st.selectbox("Select Model",
                                  options=model_options,
                                  index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0,
                                  help=model_help)

    if st.button("Save Settings"):
        st.session_state.api_key = api_key
        st.session_state.selected_model = selected_model

        try:
            test_llm = initialize_llm()
            if test_llm:
                st.success(f"✅ Successfully configured {provider} with model: {selected_model}")
            else:
                st.error("Failed to initialize the AI provider. Please check your API key and model selection.")
        except Exception as e:
            st.error(f"Error testing settings: {str(e)}")

    st.subheader("Current Settings")
    settings_info = f"""
    - **Provider**: {st.session_state.ai_provider.upper()}
    - **Model**: {st.session_state.selected_model}
    - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'}
    """
    st.markdown(settings_info)

    if st.session_state.ai_provider == "openai":
        st.info("""
        **OpenAI Models Information:**
        - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations
        - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities
        - **GPT-4o-Mini**: Economical option with good performance for standard tasks
        - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization
        """)
    else:
        st.info("""
        **Groq Models Information:**
        - **llama-3.3-70b-versatile**: Most powerful model for comprehensive analysis
        - **gemma2-9b-it**: Good balance of speed and capabilities
        - **llama3-8b-8192**: Fastest option for basic analysis tasks
        """)

    with st.expander("How to get API Keys"):
        if st.session_state.ai_provider == "openai":
            st.markdown("""
            ### Getting an OpenAI API Key

            1. Go to [OpenAI's platform](https://platform.openai.com)
            2. Sign up or log in to your account
            3. Navigate to the API section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: OpenAI API usage incurs charges based on tokens used.
            """)
        else:
            st.markdown("""
            ### Getting a Groq API Key

            1. Go to [Groq's website](https://console.groq.com/keys)
            2. Sign up or log in to your account
            3. Navigate to API Keys section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: Check Groq's pricing page for current rates.
            """)
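
# atexit fires when the Python process exits, not when a browser session
# ends, so in a long-running deployment the temp directory persists until
# the server shuts down.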
def cleanup():
    try:
        shutil.rmtree(st.session_state.temp_dir)
        print(f"Cleaned up temporary directory: {st.session_state.temp_dir}")
    except Exception as e:
        print(f"Error cleaning up: {e}")


import atexit
atexit.register(cleanup)