CosmickVisions committed on
Commit
925e1b1
·
verified ·
1 Parent(s): eeb6964

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -73
app.py CHANGED
@@ -234,6 +234,7 @@ def update_vector_store_with_plot(plot_text, existing_vector_store):
234
  return existing_vector_store
235
 
236
  def extract_plot_data(plot_info, df):
 
237
  plot_type = plot_info["type"]
238
  x_col = plot_info["x"]
239
  y_col = plot_info["y"] if "y" in plot_info else None
@@ -271,6 +272,16 @@ def extract_plot_data(plot_info, df):
271
  plot_text += f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}\n"
272
  return plot_text
273
 
 
 
 
 
 
 
 
 
 
 
274
  def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
275
  system_prompt = (
276
  "You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
@@ -368,6 +379,34 @@ def display_dataset_preview():
368
  st.dataframe(st.session_state.cleaned_data.head(10), use_container_width=True)
369
  st.markdown("---")
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  # Main App
372
  def main():
373
  # Header
@@ -562,84 +601,89 @@ def main():
562
  new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
563
  update_cleaned_data(new_df)
564
 
565
- elif app_mode == "EDA":
566
- st.header("🔍 Interactive Data Explorer")
567
  if 'cleaned_data' not in st.session_state:
568
  st.warning("Please upload and clean data first.")
569
  st.stop()
570
  df = st.session_state.cleaned_data.copy()
571
 
572
- enhance_section_title("Dataset Overview")
573
- with st.container():
574
- col1, col2, col3, col4 = st.columns(4)
575
- col1.metric("Total Rows", df.shape[0])
576
- col2.metric("Total Columns", df.shape[1])
577
- missing_percentage = df.isna().sum().sum() / df.size * 100
578
- col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
579
- col4.metric("Duplicates", df.duplicated().sum())
580
-
581
- tab1, tab2, tab3 = st.tabs(["Quick Preview", "Column Types", "Missing Matrix"])
582
- with tab1:
583
- st.write("First few rows of the dataset:")
584
- st.dataframe(df.head(), use_container_width=True)
585
- with tab2:
586
- st.write("Column Data Types:")
587
- type_counts = df.dtypes.value_counts().reset_index()
588
- type_counts.columns = ['Type', 'Count']
589
- st.dataframe(type_counts, use_container_width=True)
590
- with tab3:
591
- st.write("Missing Values Matrix:")
592
- fig_missing = px.imshow(df.isna(), color_continuous_scale=['#e0e0e0', '#66c2a5'])
593
- fig_missing.update_layout(coloraxis_colorscale=[[0, 'lightgrey'], [1, '#FF4B4B']])
594
- st.plotly_chart(fig_missing, use_container_width=True)
595
-
596
- enhance_section_title("Interactive Visualization Builder")
597
- with st.container():
598
- col1, col2 = st.columns([1, 3])
599
- with col1:
600
- plot_type = st.selectbox("Choose visualization type", [
601
- "Scatter Plot", "Histogram", "Box Plot", "Line Chart", "Bar Chart", "Correlation Matrix"
602
- ])
603
- x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
604
- y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Line Chart"] else None
605
- color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type != "Correlation Matrix" else None
606
-
607
- with col2:
608
- try:
609
- fig = None
610
- if plot_type == "Scatter Plot" and x_axis and y_axis:
611
- fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Scatter Plot of {x_axis} vs {y_axis}')
612
- elif plot_type == "Histogram" and x_axis:
613
- fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, title=f'Histogram of {x_axis}')
614
- elif plot_type == "Box Plot" and x_axis and y_axis:
615
- fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
616
- elif plot_type == "Line Chart" and x_axis and y_axis:
617
- fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
618
- elif plot_type == "Bar Chart" and x_axis:
619
- fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
620
- elif plot_type == "Correlation Matrix":
621
- numeric_df = df.select_dtypes(include=np.number)
622
- if len(numeric_df.columns) > 1:
623
- corr = numeric_df.corr()
624
- fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
625
-
626
- if fig:
627
- fig.update_layout(template="plotly_white")
628
- st.plotly_chart(fig, use_container_width=True)
629
- st.session_state.last_plot = {
630
- "type": plot_type,
631
- "x": x_axis,
632
- "y": y_axis,
633
- "data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
634
- }
635
- plot_text = extract_plot_data(st.session_state.last_plot, df)
636
- st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
637
- with st.expander("Extracted Plot Data"):
638
- st.text(plot_text)
639
- else:
640
- st.error("Please provide required inputs for the selected plot type.")
641
- except Exception as e:
642
- st.error(f"Couldn't create visualization: {str(e)}")
 
 
 
 
 
643
 
644
  # Chatbot Section
645
  st.markdown("---")
 
234
  return existing_vector_store
235
 
236
  def extract_plot_data(plot_info, df):
237
+ # Updated to handle Plotly.js JSON
238
  plot_type = plot_info["type"]
239
  x_col = plot_info["x"]
240
  y_col = plot_info["y"] if "y" in plot_info else None
 
272
  plot_text += f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}\n"
273
  return plot_text
274
 
275
def generate_3d_scatter_plot(params):
    """Build a 3D scatter plot from an "X vs Y vs Z" column specification.

    Args:
        params: Text naming three columns of ``st.session_state.cleaned_data``
            separated by the word "vs", e.g. ``"height vs weight vs age"``.

    Returns:
        The Plotly figure serialized with ``fig.to_json()``, or ``None`` when
        ``params`` does not resolve to three existing columns.
    """
    df = st.session_state.cleaned_data

    # Split on the separator instead of matching names with [\w\s]+, so that
    # column names containing punctuation ("sepal.length", "unit-price")
    # resolve instead of silently producing no plot.
    axes = [part.strip() for part in re.split(r"\s+vs\s+", params.strip())]
    if len(axes) != 3 or not all(axis in df.columns for axis in axes):
        # Fall back to the original, looser pattern so that inputs which used
        # to work (e.g. a spec followed by trailing punctuation) still do.
        match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)\s+vs\s+([\w\s]+)", params)
        if match is None:
            return None
        axes = [group.strip() for group in match.groups()]
        if not all(axis in df.columns for axis in axes):
            return None

    x_axis, y_axis, z_axis = axes
    fig = px.scatter_3d(
        df,
        x=x_axis,
        y=y_axis,
        z=z_axis,
        title=f'3D Scatter Plot of {x_axis} vs {y_axis} vs {z_axis}',
    )
    return fig.to_json()
284
+
285
  def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
286
  system_prompt = (
287
  "You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
 
379
  st.dataframe(st.session_state.cleaned_data.head(10), use_container_width=True)
380
  st.markdown("---")
381
 
382
def suggest_data_cleaning(df, threshold=0.5):
    """Suggest how to handle the missing values in each column of ``df``.

    Args:
        df: DataFrame to inspect.
        threshold: Fraction of missing rows above which dropping the column
            is suggested instead of imputing it. Defaults to 0.5, the value
            that was previously hard-coded.

    Returns:
        One "- ..." suggestion line per column that has missing values,
        joined with newlines, or "No automatic cleaning suggestions." when
        no column has any.
    """
    total_rows = df.shape[0]
    # Compute the per-column NaN counts once instead of re-scanning each column.
    missing_per_column = df.isna().sum()

    suggestions = []
    for col, na_count in missing_per_column.items():
        if na_count <= 0:
            continue
        # total_rows cannot be 0 here: a column with a missing value has rows.
        if na_count / total_rows > threshold:
            suggestions.append(
                f"- Drop column '{col}' (>{threshold * 100:g}% missing values)"
            )
        else:
            suggestions.append(
                f"- Impute missing values in column '{col}' ({na_count} missing values)"
            )
    return "\n".join(suggestions) if suggestions else "No automatic cleaning suggestions."
393
+
394
+ def parse_command(command):
395
+ # ... (Previous command parser) ...
396
+ elif "show a 3d scatter plot" in command or "3d scatter plot of" in command:
397
+ params = command.replace("show a 3d scatter plot of", "").replace("3d scatter plot of", "").strip()
398
+ return generate_3d_scatter_plot, params
399
+ # ... (rest of the function is same)
400
+
401
def parse_multistep_command(command):
    """Parse a ';'-separated command string into a list of executable steps.

    Each piece is stripped of surrounding whitespace and passed to
    ``parse_command``; pieces whose parsed function is falsy are dropped.

    Returns:
        A list of ``(func, param)`` tuples in the order the steps appeared.
    """
    # Lazy generator keeps parsing and filtering interleaved, one step at a time.
    parsed = (parse_command(piece.strip()) for piece in command.split(';'))
    return [(handler, argument) for handler, argument in parsed if handler]
409
+
410
  # Main App
411
  def main():
412
  # Header
 
601
  new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
602
  update_cleaned_data(new_df)
603
 
604
+ elif app_mode == "EDA":
605
+ st.header("🔍 Exploratory Data Analysis (EDA)")
606
  if 'cleaned_data' not in st.session_state:
607
  st.warning("Please upload and clean data first.")
608
  st.stop()
609
  df = st.session_state.cleaned_data.copy()
610
 
611
+ st.markdown("### Dataset Overview")
612
+ col1, col2, col3 = st.columns(3)
613
+ col1.metric("Rows", df.shape[0])
614
+ col2.metric("Columns", df.shape[1])
615
+ col3.metric("Missing Values", df.isna().sum().sum())
616
+
617
+ # Interactive Visualization Builder with Plotly.js
618
+ st.markdown("### Interactive Visualization Builder")
619
+ plot_type = st.selectbox("Choose visualization type", [
620
+ "Scatter Plot", "Histogram", "Box Plot", "Line Chart", "Bar Chart", "Correlation Matrix", "3D Scatter Plot"
621
+ ])
622
+ x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
623
+ y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Line Chart", "3D Scatter Plot"] else None
624
+ z_axis = st.selectbox("Z-axis", df.columns) if plot_type == "3D Scatter Plot" else None
625
+
626
+ generate_plot = st.button("Generate Plot")
627
+ if generate_plot:
628
+ fig_json = None
629
+ try:
630
+ if plot_type == "Scatter Plot":
631
+ fig = px.scatter(df, x=x_axis, y=y_axis, title=f'Scatter Plot of {x_axis} vs {y_axis}')
632
+ fig_json = fig.to_json()
633
+
634
+ elif plot_type == "Histogram":
635
+ fig = px.histogram(df, x=x_axis, title=f'Histogram of {x_axis}')
636
+ fig_json = fig.to_json()
637
+
638
+ elif plot_type == "Box Plot":
639
+ fig = px.box(df, x=x_axis, y=y_axis, title=f'Box Plot of {x_axis} vs {y_axis}')
640
+ fig_json = fig.to_json()
641
+
642
+ elif plot_type == "Line Chart":
643
+ fig = px.line(df, x=x_axis, y=y_axis, title=f'Line Chart of {x_axis} vs {y_axis}')
644
+ fig_json = fig.to_json()
645
+
646
+ elif plot_type == "Bar Chart":
647
+ fig = px.bar(df, x=x_axis, title=f'Bar Chart of {x_axis}')
648
+ fig_json = fig.to_json()
649
+
650
+ elif plot_type == "Correlation Matrix":
651
+ numeric_df = df.select_dtypes(include=np.number)
652
+ if len(numeric_df.columns) > 1:
653
+ corr = numeric_df.corr()
654
+ fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
655
+ fig_json = fig.to_json()
656
+
657
+ elif plot_type == "3D Scatter Plot":
658
+ fig_json = generate_3d_scatter_plot(f"{x_axis} vs {y_axis} vs {z_axis}")
659
+
660
+ if fig_json:
661
+ # Render Plotly.js Chart
662
+ st.components.v1.html(f"""
663
+ <div id="plotly-chart"></div>
664
+ <script>
665
+ Plotly.newPlot('plotly-chart', {fig_json});
666
+ </script>
667
+ """, height=600)
668
+
669
+ # Store Plotly JSON in session state
670
+ st.session_state.last_plot = {
671
+ "type": plot_type,
672
+ "x": x_axis,
673
+ "y": y_axis,
674
+ "z": z_axis if plot_type == "3D Scatter Plot" else None,
675
+ "data": fig_json
676
+ }
677
+
678
+ # Extract and display plot data
679
+ plot_text = extract_plot_data(st.session_state.last_plot, df)
680
+ st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
681
+ with st.expander("Extracted Plot Data"):
682
+ st.text(plot_text)
683
+
684
+ except Exception as e:
685
+ st.error(f"Couldn't generate plot: {str(e)}")
686
+
687
 
688
  # Chatbot Section
689
  st.markdown("---")