wi-lab commited on
Commit
f23306e
·
verified ·
1 Parent(s): cbea7da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -53
app.py CHANGED
@@ -61,36 +61,44 @@ def compute_f1_score(cm):
61
  f1 = np.nan_to_num(f1) # Replace NaN with 0
62
  return np.mean(f1) # Return the mean F1-score across all classes
63
 
64
- def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
65
  # Compute the average F1-score
66
  avg_f1 = compute_f1_score(cm)
67
 
68
- # Set dark mode styling
69
- plt.style.use('dark_background')
 
 
 
 
 
 
 
 
70
  plt.figure(figsize=(10, 10))
71
-
72
  # Plot the confusion matrix with a dark-mode compatible colormap
73
- #sns.heatmap(cm, cmap="magma", cbar=True, linecolor='white', vmin=0, vmax=cm.max(), alpha=0.85)
74
- sns.heatmap(cm, cmap="cividis", cbar=True, linecolor='white', vmin=0, vmax=cm.max(), alpha=0.85)
75
-
76
  # Add F1-score to the title
77
- plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
78
-
79
  tick_marks = np.arange(len(classes))
80
- plt.xticks(tick_marks, classes, color="white", fontsize=14) # White text for dark mode
81
- plt.yticks(tick_marks, classes, color="white", fontsize=14) # White text for dark mode
82
 
83
- plt.ylabel('True label', color="white", fontsize=14)
84
- plt.xlabel('Predicted label', color="white", fontsize=14)
85
  plt.tight_layout()
86
-
87
  # Save the plot as an image
88
- plt.savefig(save_path, transparent=True) # Use transparent to blend with the dark mode website
89
  plt.close()
90
-
91
  # Return the saved image
92
  return Image.open(save_path)
93
 
 
94
  def compute_average_confusion_matrix(folder):
95
  confusion_matrices = []
96
  max_num_labels = 0
@@ -140,7 +148,7 @@ from sklearn.metrics import f1_score
140
  import seaborn as sns
141
 
142
  # Function to compute confusion matrix, F1-score and plot it with dark mode style
143
- def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
144
  # Load CSV file
145
  data = pd.read_csv(csv_file_path)
146
 
@@ -153,29 +161,37 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
153
 
154
  # Compute F1-score
155
  f1 = f1_score(y_true, y_pred, average='macro') # Macro-average F1-score
 
 
 
 
 
 
 
 
 
 
156
 
157
- # Set dark mode styling
158
- plt.style.use('dark_background')
159
  plt.figure(figsize=(5, 5))
160
-
161
- # Plot the confusion matrix with a dark-mode compatible colormap
162
- sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
163
-
164
  # Add F1-score to the title
165
- plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
166
-
167
- # Customize tick labels for dark mode
168
- plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
169
- plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
170
-
171
- plt.ylabel('True label', color="white", fontsize=12)
172
- plt.xlabel('Predicted label', color="white", fontsize=12)
173
  plt.tight_layout()
174
-
175
  # Save the plot as an image
176
- plt.savefig(save_path, transparent=True) # Use transparent to blend with the dark mode website
177
  plt.close()
178
-
179
  # Return the saved image
180
  return Image.open(save_path)
181
 
@@ -308,35 +324,45 @@ def classify_based_on_distance(train_data, train_labels, test_data):
308
 
309
  return torch.tensor(predictions) # Return predictions as a PyTorch tensor
310
 
311
- def plot_confusion_matrix(y_true, y_pred, title):
312
  cm = confusion_matrix(y_true, y_pred)
313
-
314
  # Calculate F1 Score
315
  f1 = f1_score(y_true, y_pred, average='weighted')
316
 
317
- plt.style.use('dark_background')
 
 
 
 
 
 
 
 
 
318
  plt.figure(figsize=(5, 5))
319
-
320
- # Plot the confusion matrix with a dark-mode compatible colormap
321
- sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
322
-
 
323
  # Add F1-score to the title
324
- plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
325
-
326
- # Customize tick labels for dark mode
327
- plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
328
- plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
329
-
330
- plt.ylabel('True label', color="white", fontsize=12)
331
- plt.xlabel('Predicted label', color="white", fontsize=12)
332
  plt.tight_layout()
333
-
334
  # Save the plot as an image
335
- plt.savefig(f"{title}.png", transparent=True) # Use transparent to blend with the dark mode website
336
  plt.close()
337
-
338
  # Return the saved image
339
- return Image.open(f"{title}.png")
340
 
341
  def identical_train_test_split(output_emb, output_raw, labels, train_percentage):
342
  N = output_emb.shape[0]
 
61
  f1 = np.nan_to_num(f1) # Replace NaN with 0
62
  return np.mean(f1) # Return the mean F1-score across all classes
63
 
64
+ def plot_confusion_matrix_beamPred(cm, classes, title, save_path, light_mode=True):
65
  # Compute the average F1-score
66
  avg_f1 = compute_f1_score(cm)
67
 
68
+ # Choose the color scheme based on the mode
69
+ if light_mode:
70
+ plt.style.use('default') # Use default (light) mode styling
71
+ text_color = 'black'
72
+ cmap = 'Blues' # Light-mode-friendly colormap
73
+ else:
74
+ plt.style.use('dark_background') # Use dark mode styling
75
+ text_color = 'white'
76
+ cmap = 'cividis' # Dark-mode-friendly colormap
77
+
78
  plt.figure(figsize=(10, 10))
79
+
80
  # Plot the confusion matrix with a dark-mode compatible colormap
81
+ sns.heatmap(cm, cmap=cmap, cbar=True, linecolor='white', vmin=0, vmax=cm.max(), alpha=0.85)
82
+
 
83
  # Add F1-score to the title
84
+ plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color=text_color, fontsize=14)
85
+
86
  tick_marks = np.arange(len(classes))
87
+ plt.xticks(tick_marks, classes, color=text_color, fontsize=14) # Adjust text color based on the mode
88
+ plt.yticks(tick_marks, classes, color=text_color, fontsize=14) # Adjust text color based on the mode
89
 
90
+ plt.ylabel('True label', color=text_color, fontsize=14)
91
+ plt.xlabel('Predicted label', color=text_color, fontsize=14)
92
  plt.tight_layout()
93
+
94
  # Save the plot as an image
95
+ plt.savefig(save_path, transparent=True) # Transparent to blend with the site background
96
  plt.close()
97
+
98
  # Return the saved image
99
  return Image.open(save_path)
100
 
101
+
102
  def compute_average_confusion_matrix(folder):
103
  confusion_matrices = []
104
  max_num_labels = 0
 
148
  import seaborn as sns
149
 
150
  # Function to compute confusion matrix, F1-score and plot it with dark mode style
151
+ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path, light_mode=True):
152
  # Load CSV file
153
  data = pd.read_csv(csv_file_path)
154
 
 
161
 
162
  # Compute F1-score
163
  f1 = f1_score(y_true, y_pred, average='macro') # Macro-average F1-score
164
+
165
+ # Set styling based on light or dark mode
166
+ if light_mode:
167
+ plt.style.use('default') # Light mode styling
168
+ text_color = 'black'
169
+ cmap = 'Blues' # Light-mode-friendly colormap
170
+ else:
171
+ plt.style.use('dark_background') # Dark mode styling
172
+ text_color = 'white'
173
+ cmap = 'magma' # Dark-mode-friendly colormap
174
 
 
 
175
  plt.figure(figsize=(5, 5))
176
+
177
+ # Plot the confusion matrix with the chosen colormap
178
+ sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
179
+
180
  # Add F1-score to the title
181
+ plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=14)
182
+
183
+ # Customize tick labels for light/dark mode
184
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
185
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
186
+
187
+ plt.ylabel('True label', color=text_color, fontsize=12)
188
+ plt.xlabel('Predicted label', color=text_color, fontsize=12)
189
  plt.tight_layout()
190
+
191
  # Save the plot as an image
192
+ plt.savefig(save_path, transparent=True) # Use transparent to blend with the website
193
  plt.close()
194
+
195
  # Return the saved image
196
  return Image.open(save_path)
197
 
 
324
 
325
  return torch.tensor(predictions) # Return predictions as a PyTorch tensor
326
 
327
+ def plot_confusion_matrix(y_true, y_pred, title, save_path="confusion_matrix.png", light_mode=True):
328
  cm = confusion_matrix(y_true, y_pred)
329
+
330
  # Calculate F1 Score
331
  f1 = f1_score(y_true, y_pred, average='weighted')
332
 
333
+ # Choose the color scheme based on the mode
334
+ if light_mode:
335
+ plt.style.use('default') # Light mode styling
336
+ text_color = 'black'
337
+ cmap = 'Blues' # Light-mode-friendly colormap
338
+ else:
339
+ plt.style.use('dark_background') # Dark mode styling
340
+ text_color = 'white'
341
+ cmap = 'magma' # Dark-mode-friendly colormap
342
+
343
  plt.figure(figsize=(5, 5))
344
+
345
+ # Plot the confusion matrix with a colormap compatible for the mode
346
+ sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12},
347
+ linewidths=0.5, linecolor='white')
348
+
349
  # Add F1-score to the title
350
+ plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=14)
351
+
352
+ # Customize tick labels for light/dark mode
353
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
354
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
355
+
356
+ plt.ylabel('True label', color=text_color, fontsize=12)
357
+ plt.xlabel('Predicted label', color=text_color, fontsize=12)
358
  plt.tight_layout()
359
+
360
  # Save the plot as an image
361
+ plt.savefig(save_path, transparent=True) # Transparent to blend with website background
362
  plt.close()
363
+
364
  # Return the saved image
365
+ return Image.open(save_path)
366
 
367
  def identical_train_test_split(output_emb, output_raw, labels, train_percentage):
368
  N = output_emb.shape[0]