Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -31,7 +31,7 @@ def beam_prediction_task(data_percentage, task_complexity):
|
|
31 |
raw_cm = compute_average_confusion_matrix(raw_folder)
|
32 |
if raw_cm is not None:
|
33 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
34 |
-
|
35 |
raw_img = Image.open(raw_cm_path)
|
36 |
else:
|
37 |
raw_img = None
|
@@ -40,14 +40,28 @@ def beam_prediction_task(data_percentage, task_complexity):
|
|
40 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
41 |
if embeddings_cm is not None:
|
42 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
43 |
-
|
44 |
embeddings_img = Image.open(embeddings_cm_path)
|
45 |
else:
|
46 |
embeddings_img = None
|
47 |
|
48 |
return raw_img, embeddings_img
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
# Function to compute the average confusion matrix across CSV files in a folder
|
52 |
def compute_average_confusion_matrix(folder):
|
53 |
confusion_matrices = []
|
|
|
31 |
raw_cm = compute_average_confusion_matrix(raw_folder)
|
32 |
if raw_cm is not None:
|
33 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
34 |
+
plot_confusion_matrix_beamPred(raw_cm, classes=np.arange(raw_cm.shape[0]), title=f"Raw Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=raw_cm_path)
|
35 |
raw_img = Image.open(raw_cm_path)
|
36 |
else:
|
37 |
raw_img = None
|
|
|
40 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
41 |
if embeddings_cm is not None:
|
42 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
43 |
+
plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
|
44 |
embeddings_img = Image.open(embeddings_cm_path)
|
45 |
else:
|
46 |
embeddings_img = None
|
47 |
|
48 |
return raw_img, embeddings_img
|
49 |
|
50 |
+
def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
51 |
+
plt.figure(figsize=(8, 6))
|
52 |
+
plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
|
53 |
+
plt.title(title)
|
54 |
+
plt.colorbar()
|
55 |
+
tick_marks = np.arange(len(classes))
|
56 |
+
plt.xticks(tick_marks, classes, rotation=45)
|
57 |
+
plt.yticks(tick_marks, classes)
|
58 |
|
59 |
+
plt.tight_layout()
|
60 |
+
plt.ylabel('True label')
|
61 |
+
plt.xlabel('Predicted label')
|
62 |
+
plt.savefig(save_path)
|
63 |
+
plt.close()
|
64 |
+
|
65 |
# Function to compute the average confusion matrix across CSV files in a folder
|
66 |
def compute_average_confusion_matrix(folder):
|
67 |
confusion_matrices = []
|