Sadjad Alikhani commited on
Commit
9167024
·
verified ·
1 Parent(s): 5bec428

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -34,7 +34,7 @@ def beam_prediction_task(data_percentage, task_complexity):
34
  embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
35
  if embeddings_cm is not None:
36
  embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
37
- plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
38
  embeddings_img = Image.open(embeddings_cm_path)
39
  else:
40
  embeddings_img = None
@@ -71,7 +71,7 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
71
  plt.figure(figsize=(5, 5))
72
 
73
  # Plot the confusion matrix with a dark-mode compatible colormap
74
- sns.heatmap(cm, cmap="magma", cbar=False, linecolor='white')
75
 
76
  # Add F1-score to the title
77
  plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
@@ -316,34 +316,31 @@ def plot_confusion_matrix(y_true, y_pred, title):
316
 
317
  # Calculate F1 Score
318
  f1 = f1_score(y_true, y_pred, average='weighted')
319
-
320
- # Plot the confusion matrix with dark mode colors
321
  plt.figure(figsize=(5, 5))
322
- plt.imshow(cm, interpolation='nearest', cmap='coolwarm') # Dark mode color scheme
323
- plt.title(f"{title}\nF1-Score: {f1:.2f}", color='white', pad=20) # Add padding for the title to prevent clipping
324
- plt.colorbar()
325
- plt.xticks([0, 1], labels=['Class 0', 'Class 1'], color='white') # White text for dark mode
326
- plt.yticks([0, 1], labels=['Class 0', 'Class 1'], color='white') # White text for dark mode
327
-
328
- # Annotate the confusion matrix
329
- thresh = cm.max() / 2
330
- for i in range(cm.shape[0]):
331
- for j in range(cm.shape[1]):
332
- plt.text(j, i, format(cm[i, j], 'd'),
333
- ha="center", va="center",
334
- color="white" if cm[i, j] > thresh else "black")
335
-
336
- plt.ylabel('True label', color='white') # White text for dark mode
337
- plt.xlabel('Predicted label', color='white') # White text for dark mode
338
- plt.tight_layout(pad=2.0) # Add padding to prevent clipping
339
-
340
- # Save the plot
341
- plt.savefig(f"{title}.png", facecolor='black') # Set background to black for dark mode
342
  plt.close()
343
 
 
344
  return Image.open(f"{title}.png")
345
-
346
-
347
  def identical_train_test_split(output_emb, output_raw, labels, percentage):
348
  N = output_emb.shape[0] # Get the total number of samples
349
 
 
34
  embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
35
  if embeddings_cm is not None:
36
  embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
37
+ plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix\n({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
38
  embeddings_img = Image.open(embeddings_cm_path)
39
  else:
40
  embeddings_img = None
 
71
  plt.figure(figsize=(5, 5))
72
 
73
  # Plot the confusion matrix with a dark-mode compatible colormap
74
+ sns.heatmap(cm, cmap="magma", cbar=True, linecolor='white')
75
 
76
  # Add F1-score to the title
77
  plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
 
316
 
317
  # Calculate F1 Score
318
  f1 = f1_score(y_true, y_pred, average='weighted')
319
+
320
+ plt.style.use('dark_background')
321
  plt.figure(figsize=(5, 5))
322
+
323
+ # Plot the confusion matrix with a dark-mode compatible colormap
324
+ sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
325
+
326
+ # Add F1-score to the title
327
+ plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
328
+
329
+ # Customize tick labels for dark mode
330
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
331
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
332
+
333
+ plt.ylabel('True label', color="white", fontsize=12)
334
+ plt.xlabel('Predicted label', color="white", fontsize=12)
335
+ plt.tight_layout()
336
+
337
+ # Save the plot as an image
338
+ plt.savefig(f"{title}.png", transparent=True) # Use transparent to blend with the dark mode website
 
 
 
339
  plt.close()
340
 
341
+ # Return the saved image
342
  return Image.open(f"{title}.png")
343
+
 
344
  def identical_train_test_split(output_emb, output_raw, labels, percentage):
345
  N = output_emb.shape[0] # Get the total number of samples
346