Mattral commited on
Commit
976563d
·
verified ·
1 Parent(s): af43b8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -2
app.py CHANGED
@@ -5,6 +5,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
5
  from sklearn.metrics.pairwise import cosine_similarity
6
  from Levenshtein import distance as levenshtein_distance
7
  import matplotlib.pyplot as plt
 
8
 
9
 
10
  ms = st.session_state
@@ -108,6 +109,20 @@ def plot_correlation(df, column):
108
  return plt.gcf() # Return the matplotlib figure
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def main():
112
  st.title("Item Comparison App")
113
 
@@ -173,13 +188,19 @@ def main():
173
  # Show correlation plot for each dataset
174
  if st.button("Correlation for each dataset"):
175
 
176
- st.subheader("Correlation Plot for Warehouse Dataset")
177
  warehouse_corr_plot = plot_correlation(warehouse_df, warehouse_column)
178
  st.pyplot(warehouse_corr_plot)
179
 
180
- st.subheader("Correlation Plot for Industry Dataset")
181
  industry_corr_plot = plot_correlation(industry_df, industry_column)
182
  st.pyplot(industry_corr_plot)
183
 
 
 
 
 
 
 
184
  if __name__ == "__main__":
185
  main()
 
5
  from sklearn.metrics.pairwise import cosine_similarity
6
  from Levenshtein import distance as levenshtein_distance
7
  import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
 
10
 
11
  ms = st.session_state
 
109
  return plt.gcf() # Return the matplotlib figure
110
 
111
 
112
+ def plot_correlation_matrix(df):
113
+ # Filter for numeric columns, if the DataFrame has non-numeric columns
114
+ numeric_df = df.select_dtypes(include=['number'])
115
+ correlation_matrix = numeric_df.corr()
116
+
117
+ # Plotting the heatmap
118
+ plt.figure(figsize=(10, 8))
119
+ sns.heatmap(correlation_matrix, annot=True, fmt=".2f", cmap='coolwarm', cbar=True, linewidths=0.5)
120
+ plt.title("Correlation Matrix")
121
+ plt.xticks(rotation=45, ha="right")
122
+ plt.yticks(rotation=0)
123
+ plt.tight_layout() # Adjusts plot to ensure everything fits without overlap
124
+ st.pyplot() # Use Streamlit's method to display the plot
125
+
126
  def main():
127
  st.title("Item Comparison App")
128
 
 
188
  # Show correlation plot for each dataset
189
  if st.button("Correlation for each dataset"):
190
 
191
+ st.subheader("Correlation Plot for 1st Dataset")
192
  warehouse_corr_plot = plot_correlation(warehouse_df, warehouse_column)
193
  st.pyplot(warehouse_corr_plot)
194
 
195
+ st.subheader("Correlation Plot for 2nd Dataset")
196
  industry_corr_plot = plot_correlation(industry_df, industry_column)
197
  st.pyplot(industry_corr_plot)
198
 
199
+ st.subheader("Correlation Matrix for 1st Dataset")
200
+ plot_correlation_matrix(warehouse_df)
201
+
202
+ st.subheader("Correlation Matrix for 2nd Dataset")
203
+ plot_correlation_matrix(industry_df)
204
+
205
  if __name__ == "__main__":
206
  main()