Sadjad Alikhani commited on
Commit
b6a7e6d
·
verified ·
1 Parent(s): 700d590

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -49
app.py CHANGED
@@ -12,16 +12,8 @@ from sklearn.metrics import confusion_matrix
12
  import matplotlib.pyplot as plt
13
  import pandas as pd
14
 
15
- # Paths to the predefined images folder
16
- RAW_PATH = os.path.join("images", "raw")
17
- EMBEDDINGS_PATH = os.path.join("images", "embeddings")
18
-
19
- # Specific values for percentage of data for training
20
- percentage_values = (np.arange(9) + 1)*10
21
-
22
-
23
-
24
 
 
25
  def beam_prediction_task(data_percentage, task_complexity):
26
  # Folder naming convention based on input_type, data_percentage, and task_complexity
27
  raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
@@ -92,40 +84,6 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
92
  plt.savefig(save_path)
93
  plt.close()
94
 
95
-
96
- #def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
97
- # plt.figure(figsize=(8, 6))
98
- # plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
99
- # plt.title(title)
100
- # plt.colorbar()
101
- # tick_marks = np.arange(len(classes))
102
- # plt.xticks(tick_marks, classes, rotation=45)
103
- # plt.yticks(tick_marks, classes)
104
- #
105
- # plt.tight_layout()
106
- # plt.ylabel('True label')
107
- # plt.xlabel('Predicted label')
108
- # plt.savefig(save_path)
109
- # plt.close()
110
-
111
- # Function to compute the average confusion matrix across CSV files in a folder
112
- #def compute_average_confusion_matrix(folder):
113
- # confusion_matrices = []
114
- # for file in os.listdir(folder):
115
- # if file.endswith(".csv"):
116
- # data = pd.read_csv(os.path.join(folder, file))
117
- # y_true = data["Target"]
118
- # y_pred = data["Top-1 Prediction"]
119
- # num_labels = len(np.unique(y_true))
120
- # cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_labels))
121
- # confusion_matrices.append(cm)
122
- #
123
- # if confusion_matrices:
124
- # avg_cm = np.mean(confusion_matrices, axis=0)
125
- # return avg_cm
126
- # else:
127
- # return None
128
-
129
  def compute_average_confusion_matrix(folder):
130
  confusion_matrices = []
131
  max_num_labels = 0
@@ -162,10 +120,99 @@ def compute_average_confusion_matrix(folder):
162
  else:
163
  return None
164
 
 
165
 
166
 
 
 
167
 
 
 
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  # Custom class to capture print output
171
  class PrintCapture(io.StringIO):
@@ -410,7 +457,7 @@ def process_hdf5_file(uploaded_file, percentage_idx):
410
  os.chdir(original_dir)
411
  sys.stdout = sys.__stdout__ # Reset print statements
412
 
413
- # Define the Gradio interface
414
  with gr.Blocks(css="""
415
  .slider-container {
416
  display: inline-block;
@@ -439,17 +486,35 @@ with gr.Blocks(css="""
439
  # Separate Tab for LoS/NLoS Classification Task
440
  with gr.Tab("LoS/NLoS Classification Task"):
441
  gr.Markdown("### LoS/NLoS Classification Task")
442
- file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
443
 
444
- with gr.Row():
445
- percentage_dropdown_los = gr.Dropdown(choices=[0, 1, 2, 3, 4, 5, 6, 7, 8], value=0, label="Percentage of Data for Training")
 
 
 
 
 
 
 
 
446
  with gr.Row():
447
  raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
448
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
449
  output_textbox = gr.Textbox(label="Console Output", lines=10)
450
 
451
- # Placeholder for LoS/NLoS classification function (already implemented in your previous code)
452
- file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_dropdown_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
 
 
 
 
 
 
 
 
 
 
 
453
 
454
  # Launch the app
455
  if __name__ == "__main__":
 
12
  import matplotlib.pyplot as plt
13
  import pandas as pd
14
 
 
 
 
 
 
 
 
 
 
15
 
16
+ #################### BEAM PREDICTION #########################}
17
  def beam_prediction_task(data_percentage, task_complexity):
18
  # Folder naming convention based on input_type, data_percentage, and task_complexity
19
  raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
 
84
  plt.savefig(save_path)
85
  plt.close()
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def compute_average_confusion_matrix(folder):
88
  confusion_matrices = []
89
  max_num_labels = 0
 
120
  else:
121
  return None
122
 
123
+ ########################## LOS/NLOS CLASSIFICATION #############################3
124
 
125
 
126
+ # Paths to the predefined images folder
127
+ LOS_PATH = "images_LoS"
128
 
129
+ # Define the percentage values
130
+ percentage_values_los = np.linspace(0.1, 1, 20) * 100 # 20 percentage values
131
 
132
+ # Function to compute confusion matrix and plot it
133
+ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
134
+ # Load CSV file
135
+ data = pd.read_csv(csv_file_path)
136
+
137
+ # Extract ground truth and predictions
138
+ y_true = data['ground-truth']
139
+ y_pred = data['predicted']
140
+
141
+ # Compute confusion matrix
142
+ cm = confusion_matrix(y_true, y_pred)
143
+
144
+ # Plot the confusion matrix
145
+ plt.figure(figsize=(5, 5))
146
+ plt.imshow(cm, interpolation='nearest', cmap='Blues')
147
+ plt.title(title)
148
+ plt.colorbar()
149
+ plt.xticks([0, 1], labels=['Class 0', 'Class 1'])
150
+ plt.yticks([0, 1], labels=['Class 0', 'Class 1'])
151
+
152
+ # Annotate the confusion matrix
153
+ thresh = cm.max() / 2
154
+ for i in range(cm.shape[0]):
155
+ for j in range(cm.shape[1]):
156
+ plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
157
+ color="white" if cm[i, j] > thresh else "black")
158
+
159
+ plt.ylabel('True label')
160
+ plt.xlabel('Predicted label')
161
+ plt.tight_layout()
162
+
163
+ # Save the plot as an image
164
+ plt.savefig(save_path)
165
+ plt.close()
166
+
167
+ # Return the saved image
168
+ return Image.open(save_path)
169
+
170
+ # Function to load confusion matrix based on percentage and input_type
171
+ def display_confusion_matrices_los(percentage_idx):
172
+ percentage = percentage_values_los[percentage_idx]
173
+
174
+ # Construct folder names
175
+ raw_folder = os.path.join(LOS_PATH, f"raw_{percentage/100:.3f}_los_noTraining")
176
+ embeddings_folder = os.path.join(LOS_PATH, f"embedding_{percentage/100:.3f}_los_noTraining")
177
+
178
+ # Process raw confusion matrix
179
+ raw_csv_file = os.path.join(raw_folder, "confusion_matrix.csv")
180
+ raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
181
+ raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
182
+ f"Raw Confusion Matrix ({percentage:.1f}% data)",
183
+ raw_cm_img_path)
184
+
185
+ # Process embeddings confusion matrix
186
+ embeddings_csv_file = os.path.join(embeddings_folder, "confusion_matrix.csv")
187
+ embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
188
+ embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
189
+ f"Embeddings Confusion Matrix ({percentage:.1f}% data)",
190
+ embeddings_cm_img_path)
191
+
192
+ return raw_img, embeddings_img
193
+
194
+ # Main function to handle user choice
195
+ def handle_user_choice(choice, percentage_idx=None, uploaded_file=None):
196
+ if choice == "Use Predefined Data":
197
+ return display_confusion_matrices_los(percentage_idx)
198
+ elif choice == "Upload Dataset":
199
+ if uploaded_file is not None:
200
+ return process_hdf5_file(uploaded_file, percentage_idx)
201
+ else:
202
+ return "Please upload a dataset", "Please upload a dataset"
203
+ else:
204
+ return "Invalid choice", "Invalid choice"
205
+
206
+
207
+
208
+
209
+
210
+
211
+
212
+
213
+
214
+
215
+
216
 
217
  # Custom class to capture print output
218
  class PrintCapture(io.StringIO):
 
457
  os.chdir(original_dir)
458
  sys.stdout = sys.__stdout__ # Reset print statements
459
 
460
+ ######################## Define the Gradio interface ###############################
461
  with gr.Blocks(css="""
462
  .slider-container {
463
  display: inline-block;
 
486
  # Separate Tab for LoS/NLoS Classification Task
487
  with gr.Tab("LoS/NLoS Classification Task"):
488
  gr.Markdown("### LoS/NLoS Classification Task")
 
489
 
490
+ # Radio button for user choice: predefined data or upload dataset
491
+ choice_radio = gr.Radio(choices=["Use Predefined Data", "Upload Dataset"], label="Choose how to proceed", value="Use Predefined Data")
492
+
493
+ # Dropdown for selecting percentage for predefined data
494
+ percentage_dropdown_los = gr.Dropdown(choices=list(range(20)), value=0, label="Percentage of Data for Training")
495
+
496
+ # File uploader for dataset (only visible if user chooses to upload a dataset)
497
+ file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"], visible=False)
498
+
499
+ # Confusion matrices display
500
  with gr.Row():
501
  raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
502
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
503
  output_textbox = gr.Textbox(label="Console Output", lines=10)
504
 
505
+ # Update the file uploader visibility based on user choice
506
+ def toggle_file_input(choice):
507
+ return gr.update(visible=(choice == "Upload Dataset"))
508
+
509
+ choice_radio.change(fn=toggle_file_input, inputs=[choice_radio], outputs=file_input)
510
+
511
+ # When user makes a choice, update the display
512
+ choice_radio.change(fn=handle_user_choice, inputs=[choice_radio, percentage_dropdown_los, file_input],
513
+ outputs=[raw_img_los, embeddings_img_los, output_textbox])
514
+
515
+ # When percentage slider changes (for predefined data)
516
+ percentage_dropdown_los.change(fn=handle_user_choice, inputs=[choice_radio, percentage_dropdown_los, file_input],
517
+ outputs=[raw_img_los, embeddings_img_los, output_textbox])
518
 
519
  # Launch the app
520
  if __name__ == "__main__":