Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -156,24 +156,86 @@ def create_vector_store(df_text):
|
|
156 |
os.unlink(temp_path)
|
157 |
return vector_store
|
158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
160 |
-
"""Get response from Groq with vector store context"""
|
161 |
system_prompt = (
|
162 |
"You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
|
163 |
f"The user is on the '{app_mode}' page:\n"
|
164 |
"- **Data Upload**: Upload CSV/XLSX files, view stats, or generate reports.\n"
|
165 |
"- **Data Cleaning**: Clean data (e.g., handle missing values, encode variables).\n"
|
166 |
-
"- **EDA**: Visualize data (e.g., scatter plots, histograms).\n"
|
|
|
167 |
)
|
168 |
|
169 |
context = ""
|
170 |
if vector_store:
|
171 |
docs = vector_store.similarity_search(user_input, k=3)
|
172 |
if docs:
|
173 |
-
context = "\n\nDataset Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
|
174 |
-
system_prompt += f"Use this dataset context to augment your response:\n{context}"
|
175 |
else:
|
176 |
-
system_prompt += "No dataset is loaded. Assist based on app functionality."
|
177 |
|
178 |
try:
|
179 |
response = client.chat.completions.create(
|
@@ -230,20 +292,8 @@ def analyze_plot():
|
|
230 |
return "No plot available to analyze."
|
231 |
plot_info = st.session_state.last_plot
|
232 |
df = pd.read_json(plot_info["data"])
|
233 |
-
|
234 |
-
|
235 |
-
y_col = plot_info["y"] if "y" in plot_info else None
|
236 |
-
|
237 |
-
if plot_type == "Scatter Plot" and y_col:
|
238 |
-
correlation = df[x_col].corr(df[y_col])
|
239 |
-
strength = "strong" if abs(correlation) > 0.7 else "moderate" if abs(correlation) > 0.3 else "weak"
|
240 |
-
direction = "positive" if correlation > 0 else "negative"
|
241 |
-
return f"The scatter plot of {x_col} vs {y_col} shows a {strength} {direction} correlation (Pearson r = {correlation:.2f})."
|
242 |
-
elif plot_type == "Histogram":
|
243 |
-
skewness = df[x_col].skew()
|
244 |
-
skew_desc = "positively skewed" if skewness > 1 else "negatively skewed" if skewness < -1 else "approximately symmetric"
|
245 |
-
return f"The histogram of {x_col} is {skew_desc} (skewness = {skewness:.2f})."
|
246 |
-
return "Inference not available for this plot type."
|
247 |
|
248 |
def parse_command(command):
|
249 |
command = command.lower().strip()
|
@@ -529,6 +579,11 @@ def main():
|
|
529 |
"y": y_axis,
|
530 |
"data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
|
531 |
}
|
|
|
|
|
|
|
|
|
|
|
532 |
else:
|
533 |
st.error("Please provide required inputs for the selected plot type.")
|
534 |
except Exception as e:
|
|
|
156 |
os.unlink(temp_path)
|
157 |
return vector_store
|
158 |
|
159 |
+
def update_vector_store_with_plot(plot_text, existing_vector_store):
|
160 |
+
"""Update the FAISS vector store with plot-derived text"""
|
161 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
|
162 |
+
temp_file.write(plot_text)
|
163 |
+
temp_path = temp_file.name
|
164 |
+
|
165 |
+
loader = TextLoader(temp_path)
|
166 |
+
documents = loader.load()
|
167 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
|
168 |
+
texts = text_splitter.split_documents(documents)
|
169 |
+
|
170 |
+
if existing_vector_store:
|
171 |
+
existing_vector_store.add_documents(texts)
|
172 |
+
else:
|
173 |
+
existing_vector_store = FAISS.from_documents(texts, embeddings)
|
174 |
+
|
175 |
+
os.unlink(temp_path)
|
176 |
+
return existing_vector_store
|
177 |
+
|
178 |
+
def extract_plot_data(plot_info, df):
|
179 |
+
"""Extract numerical data from the last generated plot and convert to text"""
|
180 |
+
plot_type = plot_info["type"]
|
181 |
+
x_col = plot_info["x"]
|
182 |
+
y_col = plot_info["y"] if "y" in plot_info else None
|
183 |
+
data = pd.read_json(plot_info["data"])
|
184 |
+
|
185 |
+
plot_text = f"Plot Type: {plot_type}\n"
|
186 |
+
plot_text += f"X-Axis: {x_col}\n"
|
187 |
+
if y_col:
|
188 |
+
plot_text += f"Y-Axis: {y_col}\n"
|
189 |
+
|
190 |
+
if plot_type == "Scatter Plot" and y_col:
|
191 |
+
correlation = data[x_col].corr(data[y_col])
|
192 |
+
slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
|
193 |
+
plot_text += f"Correlation: {correlation:.2f}\n"
|
194 |
+
plot_text += f"Linear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, R²={r_value**2:.2f}, p-value={p_value:.4f}\n"
|
195 |
+
plot_text += f"X Stats: Mean={data[x_col].mean():.2f}, Std={data[x_col].std():.2f}, Min={data[x_col].min():.2f}, Max={data[x_col].max():.2f}\n"
|
196 |
+
plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Min={data[y_col].min():.2f}, Max={data[y_col].max():.2f}\n"
|
197 |
+
elif plot_type == "Histogram":
|
198 |
+
plot_text += f"Stats: Mean={data[x_col].mean():.2f}, Median={data[x_col].median():.2f}, Std={data[x_col].std():.2f}\n"
|
199 |
+
plot_text += f"Skewness: {data[x_col].skew():.2f}\n"
|
200 |
+
plot_text += f"Range: [{data[x_col].min():.2f}, {data[x_col].max():.2f}]\n"
|
201 |
+
elif plot_type == "Box Plot" and y_col:
|
202 |
+
q1, q3 = data[y_col].quantile(0.25), data[y_col].quantile(0.75)
|
203 |
+
iqr = q3 - q1
|
204 |
+
plot_text += f"Y Stats: Median={data[y_col].median():.2f}, Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}\n"
|
205 |
+
plot_text += f"Outliers: {len(data[y_col][(data[y_col] < q1 - 1.5 * iqr) | (data[y_col] > q3 + 1.5 * iqr)])} potential outliers\n"
|
206 |
+
elif plot_type == "Line Chart" and y_col:
|
207 |
+
plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Trend={'increasing' if data[y_col].iloc[-1] > data[y_col].iloc[0] else 'decreasing'}\n"
|
208 |
+
elif plot_type == "Bar Chart":
|
209 |
+
plot_text += f"Counts: {data[x_col].value_counts().to_dict()}\n"
|
210 |
+
elif plot_type == "Correlation Matrix":
|
211 |
+
corr = data.corr()
|
212 |
+
plot_text += "Correlation Matrix:\n"
|
213 |
+
for col1 in corr.columns:
|
214 |
+
for col2 in corr.index:
|
215 |
+
if col1 < col2: # Avoid duplicates
|
216 |
+
plot_text += f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}\n"
|
217 |
+
|
218 |
+
return plot_text
|
219 |
+
|
220 |
def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
|
221 |
+
"""Get response from Groq with vector store context including plot data"""
|
222 |
system_prompt = (
|
223 |
"You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
|
224 |
f"The user is on the '{app_mode}' page:\n"
|
225 |
"- **Data Upload**: Upload CSV/XLSX files, view stats, or generate reports.\n"
|
226 |
"- **Data Cleaning**: Clean data (e.g., handle missing values, encode variables).\n"
|
227 |
+
"- **EDA**: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
|
228 |
+
"When analyzing plots, provide detailed insights based on numerical data extracted from them."
|
229 |
)
|
230 |
|
231 |
context = ""
|
232 |
if vector_store:
|
233 |
docs = vector_store.similarity_search(user_input, k=3)
|
234 |
if docs:
|
235 |
+
context = "\n\nDataset and Plot Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
|
236 |
+
system_prompt += f"Use this dataset and plot context to augment your response:\n{context}"
|
237 |
else:
|
238 |
+
system_prompt += "No dataset or plot data is loaded. Assist based on app functionality."
|
239 |
|
240 |
try:
|
241 |
response = client.chat.completions.create(
|
|
|
292 |
return "No plot available to analyze."
|
293 |
plot_info = st.session_state.last_plot
|
294 |
df = pd.read_json(plot_info["data"])
|
295 |
+
plot_text = extract_plot_data(plot_info, df)
|
296 |
+
return f"Analysis of the last plot:\n{plot_text}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
def parse_command(command):
|
299 |
command = command.lower().strip()
|
|
|
579 |
"y": y_axis,
|
580 |
"data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
|
581 |
}
|
582 |
+
# Extract numerical data and update vector store
|
583 |
+
plot_text = extract_plot_data(st.session_state.last_plot, df)
|
584 |
+
st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
|
585 |
+
with st.expander("Extracted Plot Data"):
|
586 |
+
st.text(plot_text)
|
587 |
else:
|
588 |
st.error("Please provide required inputs for the selected plot type.")
|
589 |
except Exception as e:
|