Sadjad Alikhani commited on
Commit
ceb94a2
·
verified ·
1 Parent(s): 90c33f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -31,7 +31,7 @@ def beam_prediction_task(data_percentage, task_complexity):
31
  raw_cm = compute_average_confusion_matrix(raw_folder)
32
  if raw_cm is not None:
33
  raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
34
- plot_confusion_matrix(raw_cm, classes=np.arange(raw_cm.shape[0]), title=f"Raw Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=raw_cm_path)
35
  raw_img = Image.open(raw_cm_path)
36
  else:
37
  raw_img = None
@@ -40,14 +40,28 @@ def beam_prediction_task(data_percentage, task_complexity):
40
  embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
41
  if embeddings_cm is not None:
42
  embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
43
- plot_confusion_matrix(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
44
  embeddings_img = Image.open(embeddings_cm_path)
45
  else:
46
  embeddings_img = None
47
 
48
  return raw_img, embeddings_img
49
 
 
 
 
 
 
 
 
 
50
 
 
 
 
 
 
 
51
  # Function to compute the average confusion matrix across CSV files in a folder
52
  def compute_average_confusion_matrix(folder):
53
  confusion_matrices = []
 
31
  raw_cm = compute_average_confusion_matrix(raw_folder)
32
  if raw_cm is not None:
33
  raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
34
+ plot_confusion_matrix_beamPred(raw_cm, classes=np.arange(raw_cm.shape[0]), title=f"Raw Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=raw_cm_path)
35
  raw_img = Image.open(raw_cm_path)
36
  else:
37
  raw_img = None
 
40
  embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
41
  if embeddings_cm is not None:
42
  embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
43
+ plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
44
  embeddings_img = Image.open(embeddings_cm_path)
45
  else:
46
  embeddings_img = None
47
 
48
  return raw_img, embeddings_img
49
 
50
+ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
51
+ plt.figure(figsize=(8, 6))
52
+ plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
53
+ plt.title(title)
54
+ plt.colorbar()
55
+ tick_marks = np.arange(len(classes))
56
+ plt.xticks(tick_marks, classes, rotation=45)
57
+ plt.yticks(tick_marks, classes)
58
 
59
+ plt.tight_layout()
60
+ plt.ylabel('True label')
61
+ plt.xlabel('Predicted label')
62
+ plt.savefig(save_path)
63
+ plt.close()
64
+
65
  # Function to compute the average confusion matrix across CSV files in a folder
66
  def compute_average_confusion_matrix(folder):
67
  confusion_matrices = []