wi-lab commited on
Commit
64d3617
·
verified ·
1 Parent(s): dcec41c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -36
app.py CHANGED
@@ -324,49 +324,36 @@ def classify_based_on_distance(train_data, train_labels, test_data):
324
 
325
  return torch.tensor(predictions) # Return predictions as a PyTorch tensor
326
 
327
- def plot_confusion_matrix(y_true, y_pred, title, save_path="confusion_matrix.png", light_mode=True):
328
  cm = confusion_matrix(y_true, y_pred)
329
-
330
  # Calculate F1 Score
331
  f1 = f1_score(y_true, y_pred, average='weighted')
332
 
333
- # Choose the color scheme based on the mode
334
- if light_mode:
335
- plt.style.use('default') # Light mode styling
336
- text_color = 'black'
337
- cmap = 'Blues' # Light-mode-friendly colormap
338
- else:
339
- plt.style.use('dark_background') # Dark mode styling
340
- text_color = 'white'
341
- cmap = 'magma' # Dark-mode-friendly colormap
342
-
343
  plt.figure(figsize=(5, 5))
344
-
345
- # Plot the confusion matrix with a colormap compatible for the mode
346
- sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12},
347
- linewidths=0.5, linecolor='white')
348
-
349
  # Add F1-score to the title
350
- plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=14)
351
-
352
- # Customize tick labels for light/dark mode
353
- plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
354
- plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=10)
355
-
356
- plt.ylabel('True label', color=text_color, fontsize=12)
357
- plt.xlabel('Predicted label', color=text_color, fontsize=12)
358
  plt.tight_layout()
359
-
360
- # Save the plot as an image, ensure the image is fully written
361
- plt.savefig(save_path, bbox_inches='tight', transparent=True)
362
  plt.close()
363
-
364
- # Check if the file exists and can be opened
365
- if not os.path.exists(save_path):
366
- raise FileNotFoundError(f"File {save_path} not found.")
367
-
368
  # Return the saved image
369
- return Image.open(save_path)
 
370
 
371
  def identical_train_test_split(output_emb, output_raw, labels, train_percentage):
372
  N = output_emb.shape[0]
@@ -574,7 +561,7 @@ with gr.Blocks(css="""
574
  # Add a conclusion section at the bottom
575
  gr.Markdown("""
576
  <div class="explanation-box">
577
- **Conclusions**: By adjusting the data percentage and task complexity, you can observe how the LWM generalizes well to unseen scenarios and handles various complexities in the beam prediction task.
578
  </div>
579
  """)
580
 
@@ -636,7 +623,7 @@ with gr.Blocks(css="""
636
  # Add a conclusion section at the bottom
637
  gr.Markdown("""
638
  <div class="explanation-box">
639
- **Conclusions**: With this task, you can evaluate how well LWM embeddings perform on LoS/NLoS classification tasks, and compare it to the performance of raw channels in identifying these features.
640
  </div>
641
  """)
642
 
 
324
 
325
  return torch.tensor(predictions) # Return predictions as a PyTorch tensor
326
 
327
+ def plot_confusion_matrix(y_true, y_pred, title):
328
  cm = confusion_matrix(y_true, y_pred)
329
+
330
  # Calculate F1 Score
331
  f1 = f1_score(y_true, y_pred, average='weighted')
332
 
333
+ plt.style.use('dark_background')
 
 
 
 
 
 
 
 
 
334
  plt.figure(figsize=(5, 5))
335
+
336
+ # Plot the confusion matrix with a dark-mode compatible colormap
337
+ sns.heatmap(cm, annot=True, fmt="d", cmap="magma", cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
338
+
 
339
  # Add F1-score to the title
340
+ plt.title(f"{title}\n(F1 Score: {f1:.3f})", color="white", fontsize=14)
341
+
342
+ # Customize tick labels for dark mode
343
+ plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
344
+ plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color="white", fontsize=10)
345
+
346
+ plt.ylabel('True label', color="white", fontsize=12)
347
+ plt.xlabel('Predicted label', color="white", fontsize=12)
348
  plt.tight_layout()
349
+
350
+ # Save the plot as an image
351
+ plt.savefig(f"{title}.png", transparent=True) # Use transparent to blend with the dark mode website
352
  plt.close()
353
+
 
 
 
 
354
  # Return the saved image
355
+ return Image.open(f"{title}.png")
356
+
357
 
358
  def identical_train_test_split(output_emb, output_raw, labels, train_percentage):
359
  N = output_emb.shape[0]
 
561
  # Add a conclusion section at the bottom
562
  gr.Markdown("""
563
  <div class="explanation-box">
564
+ <b>Conclusions<b>: LWM embeddings offer such high generalization that with just a limited number of training samples, we can get high performances.
565
  </div>
566
  """)
567
 
 
623
  # Add a conclusion section at the bottom
624
  gr.Markdown("""
625
  <div class="explanation-box">
626
+ <b>Conclusions<b>: LWM CLS embeddings, although very small (raw channels size / 32), offer a rich and holistic knowledge about channels, making them suitable for a task like LoS/NLoS classfication, specifically with very limited data.
627
  </div>
628
  """)
629