Sadjad Alikhani committed on
Commit 122a1ed · verified · 1 Parent(s): b6261fc

Update app.py

Files changed (1): app.py (+50 −42)
app.py CHANGED
@@ -62,7 +62,6 @@ def compute_f1_score(cm):
     f1 = np.nan_to_num(f1)  # Replace NaN with 0
     return np.mean(f1)  # Return the mean F1-score across all classes
 
-# Function to plot and save confusion matrix with F1-score in the title
 def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
     # Compute the average F1-score
     avg_f1 = compute_f1_score(cm)
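compute_f1_score itself is outside this hunk; the context lines above only show that it replaces NaNs with 0 and returns the mean F1-score across classes. A minimal sketch of such a helper, assuming the confusion matrix has true labels on rows and predicted labels on columns (the function name and exact formula below are illustrative, not the repository's code):

import numpy as np

def compute_f1_score_sketch(cm):
    cm = np.asarray(cm, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        tp = np.diag(cm)                    # correct predictions per class
        precision = tp / cm.sum(axis=0)     # column sums = predicted counts per class
        recall = tp / cm.sum(axis=1)        # row sums = true counts per class
        f1 = 2 * precision * recall / (precision + recall)
    f1 = np.nan_to_num(f1)                  # replace NaN with 0, as in the diff
    return np.mean(f1)                      # mean F1-score across all classes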
@@ -70,20 +69,22 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
     # Update title to include average F1-score
     full_title = f"{title} (Avg F1-Score: {avg_f1:.2f})"
 
-    # Plot the confusion matrix
+    # Plot the confusion matrix with dark mode adjustments
     plt.figure(figsize=(8, 6))
-    plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
-    plt.title(full_title)
+    plt.imshow(cm, interpolation='nearest', cmap='coolwarm')  # Dark mode color scheme
+    plt.title(full_title, color='white', pad=20)  # Add padding to prevent title clipping, white text for dark mode
     plt.colorbar()
 
     tick_marks = np.arange(len(classes))
-    plt.xticks(tick_marks, classes, rotation=45)
-    plt.yticks(tick_marks, classes)
+    plt.xticks(tick_marks, classes, rotation=45, color='white')  # White text for dark mode
+    plt.yticks(tick_marks, classes, color='white')  # White text for dark mode
 
-    plt.tight_layout()
-    plt.ylabel('True label')
-    plt.xlabel('Predicted label')
-    plt.savefig(save_path)
+    plt.tight_layout(pad=2.0)  # Add padding to prevent axis label clipping
+    plt.ylabel('True label', color='white')  # White text for dark mode
+    plt.xlabel('Predicted label', color='white')  # White text for dark mode
+
+    # Save the plot with a black background for dark mode
+    plt.savefig(save_path, facecolor='black')
     plt.close()
 
 def compute_average_confusion_matrix(folder):
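A minimal usage sketch of the updated helper; the confusion matrix, class names, and output path below are illustrative. Passing facecolor='black' to plt.savefig is what gives the exported image a dark background so it matches the dark-mode page, independent of the white text set on the axes:

import numpy as np

cm = np.array([[50, 5],
               [8, 37]])                  # dummy 2x2 confusion matrix
plot_confusion_matrix_beamPred(
    cm,
    classes=['Beam 0', 'Beam 1'],         # illustrative class names
    title='Beam Prediction',              # illustrative title
    save_path='beam_pred_cm.png',         # illustrative output path
)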
@@ -135,7 +136,6 @@ percentage_values_los = np.linspace(0.001, 1, 20) * 100  # 20 percentage values
 from sklearn.metrics import f1_score
 import seaborn as sns
 
-# Function to compute confusion matrix, F1-score and plot it with dark mode style
 def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
     # Load CSV file
     data = pd.read_csv(csv_file_path)
@@ -147,29 +147,30 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
     # Compute confusion matrix
     cm = confusion_matrix(y_true, y_pred)
 
-    # Compute F1-score
-    f1 = f1_score(y_true, y_pred, average='macro')  # Macro-average F1-score
-
-    # Set dark mode styling
-    plt.style.use('dark_background')
+    # Calculate F1 Score
+    f1 = f1_score(y_true, y_pred, average='weighted')
+
+    # Plot the confusion matrix with dark mode colors
     plt.figure(figsize=(5, 5))
+    plt.imshow(cm, interpolation='nearest', cmap='coolwarm')  # Dark mode color scheme
+    plt.title(f"{title}\nF1-Score: {f1:.2f}", color='white', pad=20)  # Display F1-Score in title, add padding for better visibility
+    plt.colorbar()
+    plt.xticks([0, 1], labels=['Class 0', 'Class 1'], color='white')  # White text for dark mode
+    plt.yticks([0, 1], labels=['Class 0', 'Class 1'], color='white')  # White text for dark mode
 
-    # Plot the confusion matrix with a dark-mode compatible colormap
-    sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
-
-    # Add F1-score to the title
-    plt.title(f"{title} (F1 Score: {f1:.3f})", color="white", fontsize=14)
-
-    # Customize tick labels for dark mode
-    plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
-    plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
+    # Annotate the confusion matrix
+    thresh = cm.max() / 2
+    for i in range(cm.shape[0]):
+        for j in range(cm.shape[1]):
+            plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
+                     color="white" if cm[i, j] > thresh else "black")
 
-    plt.ylabel('True label', color="white", fontsize=12)
-    plt.xlabel('Predicted label', color="white", fontsize=12)
-    plt.tight_layout()
+    plt.ylabel('True label', color='white')  # White text for dark mode
+    plt.xlabel('Predicted label', color='white')  # White text for dark mode
+    plt.tight_layout(pad=2.0)  # Add padding to prevent clipping of labels
 
     # Save the plot as an image
-    plt.savefig(save_path, transparent=True)  # Use transparent to blend with the dark mode website
+    plt.savefig(save_path, facecolor='black')  # Set background to black for dark mode
     plt.close()
 
     # Return the saved image
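Note the averaging change from 'macro' to 'weighted': macro takes the unweighted mean of per-class F1 scores, while weighted scales each class's F1 by its support, so the two differ on imbalanced data. A small illustrative comparison with dummy labels:

from sklearn.metrics import f1_score

y_true = [0, 0, 0, 0, 0, 0, 1, 1]   # imbalanced dummy labels
y_pred = [0, 0, 0, 0, 0, 1, 1, 0]
print(f1_score(y_true, y_pred, average='macro'))     # unweighted mean of per-class F1
print(f1_score(y_true, y_pred, average='weighted'))  # per-class F1 weighted by class support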
@@ -304,32 +305,39 @@ def classify_based_on_distance(train_data, train_labels, test_data):
 
     return torch.tensor(predictions)  # Return predictions as a PyTorch tensor
 
-# Function to generate confusion matrix plot
 def plot_confusion_matrix(y_true, y_pred, title):
     cm = confusion_matrix(y_true, y_pred)
+
+    # Calculate F1 Score
+    f1 = f1_score(y_true, y_pred, average='weighted')
+
+    # Plot the confusion matrix with dark mode colors
     plt.figure(figsize=(5, 5))
-    plt.imshow(cm, cmap='Blues')
-    plt.title(title)
-    plt.xlabel('Predicted')
-    plt.ylabel('Actual')
+    plt.imshow(cm, interpolation='nearest', cmap='coolwarm')  # Dark mode color scheme
+    plt.title(f"{title}\nF1-Score: {f1:.2f}", color='white', pad=20)  # Add padding for the title to prevent clipping
     plt.colorbar()
-
-    # Add labels for x and y ticks (Actual/Predicted class labels)
-    plt.xticks([0, 1], labels=[0, 1])
-    plt.yticks([0, 1], labels=[0, 1])
-
+    plt.xticks([0, 1], labels=['Class 0', 'Class 1'], color='white')  # White text for dark mode
+    plt.yticks([0, 1], labels=['Class 0', 'Class 1'], color='white')  # White text for dark mode
+
     # Annotate the confusion matrix
-    thresh = cm.max() / 2  # Define threshold to choose text color (black or white)
+    thresh = cm.max() / 2
     for i in range(cm.shape[0]):
         for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")
 
-    plt.tight_layout()
-    plt.savefig(f"{title}.png")
+    plt.ylabel('True label', color='white')  # White text for dark mode
+    plt.xlabel('Predicted label', color='white')  # White text for dark mode
+    plt.tight_layout(pad=2.0)  # Add padding to prevent clipping
+
+    # Save the plot
+    plt.savefig(f"{title}.png", facecolor='black')  # Set background to black for dark mode
+    plt.close()
+
     return Image.open(f"{title}.png")
 
+
 def identical_train_test_split(output_emb, output_raw, labels, percentage):
     N = output_emb.shape[0]  # Get the total number of samples
 
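plot_confusion_matrix writes the figure to "<title>.png" and returns it through PIL's Image.open, so callers get an image object back. A minimal call sketch with dummy labels (f1_score and Image are assumed to be imported elsewhere in app.py, as the hunks above suggest):

y_true = [0, 1, 0, 1, 1, 0]   # dummy ground-truth labels
y_pred = [0, 1, 1, 1, 0, 0]   # dummy predictions
img = plot_confusion_matrix(y_true, y_pred, title="Demo")
print(img.size)               # a PIL.Image.Image loaded from "Demo.png"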
 