Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -65,27 +65,32 @@ def compute_f1_score(cm):
|
|
65 |
def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
66 |
# Compute the average F1-score
|
67 |
avg_f1 = compute_f1_score(cm)
|
|
|
|
|
|
|
|
|
68 |
|
69 |
-
#
|
70 |
-
|
71 |
|
72 |
-
#
|
73 |
-
plt.
|
74 |
-
plt.imshow(cm, interpolation='nearest', cmap='coolwarm') # Dark mode color scheme
|
75 |
-
plt.title(full_title, color='white', pad=20) # Add padding to prevent title clipping, white text for dark mode
|
76 |
-
plt.colorbar()
|
77 |
|
78 |
-
|
79 |
-
plt.xticks(
|
80 |
-
plt.yticks(
|
81 |
-
|
82 |
-
plt.
|
83 |
-
plt.
|
84 |
-
plt.
|
85 |
-
|
86 |
-
# Save the plot
|
87 |
-
plt.savefig(save_path,
|
88 |
plt.close()
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def compute_average_confusion_matrix(folder):
|
91 |
confusion_matrices = []
|
@@ -158,7 +163,7 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
|
|
158 |
sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
159 |
|
160 |
# Add F1-score to the title
|
161 |
-
plt.title(f"{title}
|
162 |
|
163 |
# Customize tick labels for dark mode
|
164 |
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
|
|
65 |
def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
66 |
# Compute the average F1-score
|
67 |
avg_f1 = compute_f1_score(cm)
|
68 |
+
|
69 |
+
# Set dark mode styling
|
70 |
+
plt.style.use('dark_background')
|
71 |
+
plt.figure(figsize=(5, 5))
|
72 |
|
73 |
+
# Plot the confusion matrix with a dark-mode compatible colormap
|
74 |
+
sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
75 |
|
76 |
+
# Add F1-score to the title
|
77 |
+
plt.title(f"{title}\n(F1 Score: {avg_f1:.3f})", color="white", fontsize=14)
|
|
|
|
|
|
|
78 |
|
79 |
+
# Customize tick labels for dark mode
|
80 |
+
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
81 |
+
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|
82 |
+
|
83 |
+
plt.ylabel('True label', color="white", fontsize=12)
|
84 |
+
plt.xlabel('Predicted label', color="white", fontsize=12)
|
85 |
+
plt.tight_layout()
|
86 |
+
|
87 |
+
# Save the plot as an image
|
88 |
+
plt.savefig(save_path, transparent=True) # Use transparent to blend with the dark mode website
|
89 |
plt.close()
|
90 |
+
|
91 |
+
# Return the saved image
|
92 |
+
return Image.open(save_path)
|
93 |
+
|
94 |
|
95 |
def compute_average_confusion_matrix(folder):
|
96 |
confusion_matrices = []
|
|
|
163 |
sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
164 |
|
165 |
# Add F1-score to the title
|
166 |
+
plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
|
167 |
|
168 |
# Customize tick labels for dark mode
|
169 |
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
|