Sadjad Alikhani commited on
Commit
0034c8b
·
verified ·
1 Parent(s): f5ccec5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -71,14 +71,15 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
71
  plt.figure(figsize=(5, 5))
72
 
73
  # Plot the confusion matrix with a dark-mode compatible colormap
74
- sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
75
 
76
  # Add F1-score to the title
77
  plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
78
 
79
- # Customize tick labels for dark mode
80
- plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
81
- plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
 
82
 
83
  plt.ylabel('True label', color="white", fontsize=12)
84
  plt.xlabel('Predicted label', color="white", fontsize=12)
@@ -92,6 +93,7 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
92
  return Image.open(save_path)
93
 
94
 
 
95
  def compute_average_confusion_matrix(folder):
96
  confusion_matrices = []
97
  max_num_labels = 0
 
71
  plt.figure(figsize=(5, 5))
72
 
73
  # Plot the confusion matrix with a dark-mode compatible colormap
74
+ sns.heatmap(cm, annot=True, fmt=".2f", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
75
 
76
  # Add F1-score to the title
77
  plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
78
 
79
+ tick_marks = np.arange(len(classes))
80
+ plt.xticks(tick_marks, classes, color="white", fontsize=10) # White text for dark mode
81
+ plt.yticks(tick_marks, classes, color="white", fontsize=10) # White text for dark mode
82
+
83
 
84
  plt.ylabel('True label', color="white", fontsize=12)
85
  plt.xlabel('Predicted label', color="white", fontsize=12)
 
93
  return Image.open(save_path)
94
 
95
 
96
+
97
  def compute_average_confusion_matrix(folder):
98
  confusion_matrices = []
99
  max_num_labels = 0