Sadjad Alikhani commited on
Commit
c6aa746
·
verified ·
1 Parent(s): 68cc6b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -38
app.py CHANGED
@@ -18,6 +18,57 @@ EMBEDDINGS_PATH = os.path.join("images", "embeddings")
18
  # Specific values for percentage of data for training
19
  percentage_values = (np.arange(9) + 1)*10
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Custom class to capture print output
22
  class PrintCapture(io.StringIO):
23
  def __init__(self):
@@ -263,12 +314,6 @@ def process_hdf5_file(uploaded_file, percentage_idx):
263
 
264
  # Define the Gradio interface
265
  with gr.Blocks(css="""
266
- .vertical-slider input[type=range] {
267
- writing-mode: bt-lr; /* IE */
268
- -webkit-appearance: slider-vertical; /* WebKit */
269
- width: 8px;
270
- height: 200px;
271
- }
272
  .slider-container {
273
  display: inline-block;
274
  margin-right: 50px;
@@ -276,55 +321,39 @@ with gr.Blocks(css="""
276
  }
277
  """) as demo:
278
 
279
- # Contact Section
280
- gr.Markdown("""
281
- <div style="text-align: center;">
282
- <a target="_blank" href="https://www.wi-lab.net">
283
- <img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;">
284
- </a>
285
- <a target="_blank" href="mailto:alikhani@asu.edu" style="margin-left: 10px;">
286
- <img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail" alt="Email">
287
- </a>
288
- </div>
289
- """)
290
-
291
- # Tabs for Beam Prediction and LoS/NLoS Classification
292
  with gr.Tab("Beam Prediction Task"):
293
  gr.Markdown("### Beam Prediction Task")
294
 
295
  with gr.Row():
296
- with gr.Column(elem_id="slider-container"):
297
- gr.Markdown("Percentage of Data for Training")
298
- percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
299
 
300
  with gr.Row():
301
- raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
302
- embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
303
 
304
- percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
 
 
305
 
 
306
  with gr.Tab("LoS/NLoS Classification Task"):
307
  gr.Markdown("### LoS/NLoS Classification Task")
308
-
309
  file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
310
 
311
  with gr.Row():
312
- with gr.Column(elem_id="slider-container"):
313
- gr.Markdown("Percentage of Data for Training")
314
- #percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
315
- percentage_dropdown_los = gr.Dropdown(choices=[0, 1, 2, 3, 4, 5, 6, 7, 8],
316
- value=0,
317
- label="Percentage of Data for Training",
318
- interactive=True)
319
-
320
  with gr.Row():
321
- raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
322
- embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
323
  output_textbox = gr.Textbox(label="Console Output", lines=10)
324
 
 
325
  file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_dropdown_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
326
- percentage_dropdown_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_dropdown_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
327
 
328
  # Launch the app
329
  if __name__ == "__main__":
330
- demo.launch()
 
 
18
  # Specific values for percentage of data for training
19
  percentage_values = (np.arange(9) + 1)*10
20
 
21
+
22
+
23
+
24
+ def beam_prediction_task(data_percentage, task_complexity):
25
+ # Folder naming convention based on input_type, data_percentage, and task_complexity
26
+ raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
27
+ embeddings_folder = f"images/embeddings_{data_percentage/100:.1f}_{task_complexity}"
28
+
29
+ # Process raw confusion matrix
30
+ raw_cm = compute_average_confusion_matrix(raw_folder)
31
+ if raw_cm is not None:
32
+ raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
33
+ plot_confusion_matrix(raw_cm, classes=np.arange(raw_cm.shape[0]), title=f"Raw Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=raw_cm_path)
34
+ raw_img = Image.open(raw_cm_path)
35
+ else:
36
+ raw_img = None
37
+
38
+ # Process embeddings confusion matrix
39
+ embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
40
+ if embeddings_cm is not None:
41
+ embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
42
+ plot_confusion_matrix(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
43
+ embeddings_img = Image.open(embeddings_cm_path)
44
+ else:
45
+ embeddings_img = None
46
+
47
+ return raw_img, embeddings_img
48
+
49
+
50
+ # Function to compute the average confusion matrix across CSV files in a folder
51
+ def compute_average_confusion_matrix(folder):
52
+ confusion_matrices = []
53
+ for file in os.listdir(folder):
54
+ if file.endswith(".csv"):
55
+ data = pd.read_csv(os.path.join(folder, file))
56
+ y_true = data["Target"]
57
+ y_pred = data["Top-1 Prediction"]
58
+ num_labels = len(np.unique(y_true))
59
+ cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_labels))
60
+ confusion_matrices.append(cm)
61
+
62
+ if confusion_matrices:
63
+ avg_cm = np.mean(confusion_matrices, axis=0)
64
+ return avg_cm
65
+ else:
66
+ return None
67
+
68
+
69
+
70
+
71
+
72
  # Custom class to capture print output
73
  class PrintCapture(io.StringIO):
74
  def __init__(self):
 
314
 
315
  # Define the Gradio interface
316
  with gr.Blocks(css="""
 
 
 
 
 
 
317
  .slider-container {
318
  display: inline-block;
319
  margin-right: 50px;
 
321
  }
322
  """) as demo:
323
 
324
+ # Tab for Beam Prediction Task
 
 
 
 
 
 
 
 
 
 
 
 
325
  with gr.Tab("Beam Prediction Task"):
326
  gr.Markdown("### Beam Prediction Task")
327
 
328
  with gr.Row():
329
+ with gr.Column():
330
+ data_percentage_slider = gr.Slider(label="Data Percentage for Training", minimum=10, maximum=100, step=10, value=10)
331
+ task_complexity_slider = gr.Slider(label="Task Complexity (Number of Beams)", minimum=16, maximum=256, value=16, choices=[16, 32, 64, 128, 256])
332
 
333
  with gr.Row():
334
+ raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
335
+ embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300)
336
 
337
+ # Update the confusion matrices whenever sliders change
338
+ data_percentage_slider.change(fn=beam_prediction_task, inputs=[data_percentage_slider, task_complexity_slider], outputs=[raw_img_bp, embeddings_img_bp])
339
+ task_complexity_slider.change(fn=beam_prediction_task, inputs=[data_percentage_slider, task_complexity_slider], outputs=[raw_img_bp, embeddings_img_bp])
340
 
341
+ # Separate Tab for LoS/NLoS Classification Task
342
  with gr.Tab("LoS/NLoS Classification Task"):
343
  gr.Markdown("### LoS/NLoS Classification Task")
 
344
  file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
345
 
346
  with gr.Row():
347
+ percentage_dropdown_los = gr.Dropdown(choices=[0, 1, 2, 3, 4, 5, 6, 7, 8], value=0, label="Percentage of Data for Training")
 
 
 
 
 
 
 
348
  with gr.Row():
349
+ raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
350
+ embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
351
  output_textbox = gr.Textbox(label="Console Output", lines=10)
352
 
353
+ # Placeholder for LoS/NLoS classification function (already implemented in your previous code)
354
  file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_dropdown_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
 
355
 
356
  # Launch the app
357
  if __name__ == "__main__":
358
+ demo.launch()
359
+