Sadjad Alikhani commited on
Commit
244e2b5
·
verified ·
1 Parent(s): c565027

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -130,6 +130,10 @@ LOS_PATH = "images_LoS"
130
  percentage_values_los = np.linspace(0.001, 1, 20) * 100 # 20 percentage values
131
 
132
  # Function to compute confusion matrix and plot it
 
 
 
 
133
  def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
134
  # Load CSV file
135
  data = pd.read_csv(csv_file_path)
@@ -141,27 +145,29 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
141
  # Compute confusion matrix
142
  cm = confusion_matrix(y_true, y_pred)
143
 
144
- # Plot the confusion matrix
 
 
 
 
145
  plt.figure(figsize=(5, 5))
146
- plt.imshow(cm, interpolation='nearest', cmap='Blues')
147
- plt.title(title)
148
- plt.colorbar()
149
- plt.xticks([0, 1], labels=['Class 0', 'Class 1'])
150
- plt.yticks([0, 1], labels=['Class 0', 'Class 1'])
151
 
152
- # Annotate the confusion matrix
153
- thresh = cm.max() / 2
154
- for i in range(cm.shape[0]):
155
- for j in range(cm.shape[1]):
156
- plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
157
- color="white" if cm[i, j] > thresh else "black")
158
 
159
- plt.ylabel('True label')
160
- plt.xlabel('Predicted label')
 
 
 
 
 
 
 
161
  plt.tight_layout()
162
 
163
  # Save the plot as an image
164
- plt.savefig(save_path)
165
  plt.close()
166
 
167
  # Return the saved image
@@ -473,7 +479,8 @@ with gr.Blocks(css="""
473
  choice_radio = gr.Radio(choices=["Use Default Dataset", "Upload Dataset"], label="Choose how to proceed", value="Use Default Dataset")
474
 
475
  # Dropdown for selecting percentage for predefined data
476
- percentage_dropdown_los = gr.Dropdown(choices=[f"{value:.3f}" for value in percentage_values_los], value=f"{percentage_values_los[0]:.3f}", label="Percentage of Data for Training")
 
477
 
478
  # File uploader for dataset (only visible if user chooses to upload a dataset)
479
  file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"], visible=False)
 
130
  percentage_values_los = np.linspace(0.001, 1, 20) * 100 # 20 percentage values
131
 
132
  # Function to compute confusion matrix and plot it
133
+ from sklearn.metrics import f1_score
134
+ import seaborn as sns
135
+
136
+ # Function to compute confusion matrix, F1-score and plot it with dark mode style
137
  def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
138
  # Load CSV file
139
  data = pd.read_csv(csv_file_path)
 
145
  # Compute confusion matrix
146
  cm = confusion_matrix(y_true, y_pred)
147
 
148
+ # Compute F1-score
149
+ f1 = f1_score(y_true, y_pred, average='macro') # Macro-average F1-score
150
+
151
+ # Set dark mode styling
152
+ plt.style.use('dark_background')
153
  plt.figure(figsize=(5, 5))
 
 
 
 
 
154
 
155
+ # Plot the confusion matrix with a dark-mode compatible colormap
156
+ sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
 
 
 
 
157
 
158
+ # Add F1-score to the title
159
+ plt.title(f"{title} (F1 Score: {f1:.3f})", color="white", fontsize=14)
160
+
161
+ # Customize tick labels for dark mode
162
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
163
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
164
+
165
+ plt.ylabel('True label', color="white", fontsize=12)
166
+ plt.xlabel('Predicted label', color="white", fontsize=12)
167
  plt.tight_layout()
168
 
169
  # Save the plot as an image
170
+ plt.savefig(save_path, transparent=True) # Use transparent to blend with the dark mode website
171
  plt.close()
172
 
173
  # Return the saved image
 
479
  choice_radio = gr.Radio(choices=["Use Default Dataset", "Upload Dataset"], label="Choose how to proceed", value="Use Default Dataset")
480
 
481
  # Dropdown for selecting percentage for predefined data
482
+ #percentage_dropdown_los = gr.Dropdown(choices=[f"{value:.3f}" for value in percentage_values_los], value=f"{percentage_values_los[0]:.3f}", label="Percentage of Data for Training")
483
+ percentage_dropdown_los = gr.Dropdown(choices=list(range(20)), value=0, label="Percentage of Data for Training")
484
 
485
  # File uploader for dataset (only visible if user chooses to upload a dataset)
486
  file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"], visible=False)