Sadjad Alikhani commited on
Commit
fc197f1
·
verified ·
1 Parent(s): c95d871

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -2
app.py CHANGED
@@ -63,16 +63,52 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
63
  plt.close()
64
 
65
  # Function to compute the average confusion matrix across CSV files in a folder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def compute_average_confusion_matrix(folder):
67
  confusion_matrices = []
 
 
 
 
 
 
 
 
 
 
68
  for file in os.listdir(folder):
69
  if file.endswith(".csv"):
70
  data = pd.read_csv(os.path.join(folder, file))
71
  y_true = data["Target"]
72
  y_pred = data["Top-1 Prediction"]
73
  num_labels = len(np.unique(y_true))
74
- cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_labels))
75
- confusion_matrices.append(cm)
 
 
 
 
 
 
 
 
 
76
 
77
  if confusion_matrices:
78
  avg_cm = np.mean(confusion_matrices, axis=0)
@@ -84,6 +120,7 @@ def compute_average_confusion_matrix(folder):
84
 
85
 
86
 
 
87
  # Custom class to capture print output
88
  class PrintCapture(io.StringIO):
89
  def __init__(self):
 
63
  plt.close()
64
 
65
  # Function to compute the average confusion matrix across CSV files in a folder
66
+ #def compute_average_confusion_matrix(folder):
67
+ # confusion_matrices = []
68
+ # for file in os.listdir(folder):
69
+ # if file.endswith(".csv"):
70
+ # data = pd.read_csv(os.path.join(folder, file))
71
+ # y_true = data["Target"]
72
+ # y_pred = data["Top-1 Prediction"]
73
+ # num_labels = len(np.unique(y_true))
74
+ # cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_labels))
75
+ # confusion_matrices.append(cm)
76
+ #
77
+ # if confusion_matrices:
78
+ # avg_cm = np.mean(confusion_matrices, axis=0)
79
+ # return avg_cm
80
+ # else:
81
+ # return None
82
+
83
  def compute_average_confusion_matrix(folder):
84
  confusion_matrices = []
85
+ max_num_labels = 0
86
+
87
+ # First pass to determine the maximum number of labels
88
+ for file in os.listdir(folder):
89
+ if file.endswith(".csv"):
90
+ data = pd.read_csv(os.path.join(folder, file))
91
+ num_labels = len(np.unique(data["Target"]))
92
+ max_num_labels = max(max_num_labels, num_labels)
93
+
94
+ # Second pass to calculate the confusion matrices and pad if necessary
95
  for file in os.listdir(folder):
96
  if file.endswith(".csv"):
97
  data = pd.read_csv(os.path.join(folder, file))
98
  y_true = data["Target"]
99
  y_pred = data["Top-1 Prediction"]
100
  num_labels = len(np.unique(y_true))
101
+
102
+ # Compute confusion matrix
103
+ cm = confusion_matrix(y_true, y_pred, labels=np.arange(max_num_labels))
104
+
105
+ # If the confusion matrix is smaller, pad it to match the largest size
106
+ if cm.shape[0] < max_num_labels:
107
+ padded_cm = np.zeros((max_num_labels, max_num_labels))
108
+ padded_cm[:cm.shape[0], :cm.shape[1]] = cm
109
+ confusion_matrices.append(padded_cm)
110
+ else:
111
+ confusion_matrices.append(cm)
112
 
113
  if confusion_matrices:
114
  avg_cm = np.mean(confusion_matrices, axis=0)
 
120
 
121
 
122
 
123
+
124
  # Custom class to capture print output
125
  class PrintCapture(io.StringIO):
126
  def __init__(self):