Sadjad Alikhani commited on
Commit
f5ccec5
·
verified ·
1 Parent(s): 86755ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -65,27 +65,32 @@ def compute_f1_score(cm):
65
  def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
66
  # Compute the average F1-score
67
  avg_f1 = compute_f1_score(cm)
 
 
 
 
68
 
69
- # Update title to include average F1-score
70
- full_title = f"{title} (Avg F1-Score: {avg_f1:.2f})"
71
 
72
- # Plot the confusion matrix with dark mode adjustments
73
- plt.figure(figsize=(8, 6))
74
- plt.imshow(cm, interpolation='nearest', cmap='coolwarm') # Dark mode color scheme
75
- plt.title(full_title, color='white', pad=20) # Add padding to prevent title clipping, white text for dark mode
76
- plt.colorbar()
77
 
78
- tick_marks = np.arange(len(classes))
79
- plt.xticks(tick_marks, classes, rotation=45, color='white') # White text for dark mode
80
- plt.yticks(tick_marks, classes, color='white') # White text for dark mode
81
-
82
- plt.tight_layout(pad=2.0) # Add padding to prevent axis label clipping
83
- plt.ylabel('True label', color='white') # White text for dark mode
84
- plt.xlabel('Predicted label', color='white') # White text for dark mode
85
-
86
- # Save the plot with a black background for dark mode
87
- plt.savefig(save_path, facecolor='black')
88
  plt.close()
 
 
 
 
89
 
90
  def compute_average_confusion_matrix(folder):
91
  confusion_matrices = []
@@ -158,7 +163,7 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
158
  sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
159
 
160
  # Add F1-score to the title
161
- plt.title(f"{title} (F1 Score: {f1:.3f})", color="white", fontsize=14)
162
 
163
  # Customize tick labels for dark mode
164
  plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
 
65
  def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
66
  # Compute the average F1-score
67
  avg_f1 = compute_f1_score(cm)
68
+
69
+ # Set dark mode styling
70
+ plt.style.use('dark_background')
71
+ plt.figure(figsize=(5, 5))
72
 
73
+ # Plot the confusion matrix with a dark-mode compatible colormap
74
+ sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
75
 
76
+ # Add F1-score to the title
77
+ plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
 
 
 
78
 
79
+ # Customize tick labels for dark mode
80
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
81
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
82
+
83
+ plt.ylabel('True label', color="white", fontsize=12)
84
+ plt.xlabel('Predicted label', color="white", fontsize=12)
85
+ plt.tight_layout()
86
+
87
+ # Save the plot as an image
88
+ plt.savefig(save_path, transparent=True) # Use transparent to blend with the dark mode website
89
  plt.close()
90
+
91
+ # Return the saved image
92
+ return Image.open(save_path)
93
+
94
 
95
  def compute_average_confusion_matrix(folder):
96
  confusion_matrices = []
 
163
  sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
164
 
165
  # Add F1-score to the title
166
+ plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
167
 
168
  # Customize tick labels for dark mode
169
  plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)