glenn-jocher commited on
Commit
e8c5237
1 Parent(s): ec2da4a

ConfusionMatrix `normalize=True` fix (#3587)

Browse files
Files changed (1) hide show
  1. utils/metrics.py +3 -4
utils/metrics.py CHANGED
@@ -161,9 +161,8 @@ class ConfusionMatrix:
161
  def plot(self, normalize=True, save_dir='', names=()):
162
  try:
163
  import seaborn as sn
164
-
165
- if normalize:
166
- array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns
167
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
168
 
169
  fig = plt.figure(figsize=(12, 9), tight_layout=True)
@@ -178,7 +177,7 @@ class ConfusionMatrix:
178
  fig.axes[0].set_ylabel('Predicted')
179
  fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
180
  except Exception as e:
181
- pass
182
 
183
  def print(self):
184
  for i in range(self.nc + 1):
 
161
  def plot(self, normalize=True, save_dir='', names=()):
162
  try:
163
  import seaborn as sn
164
+
165
+ array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-6) if normalize else 1) # normalize columns
 
166
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
167
 
168
  fig = plt.figure(figsize=(12, 9), tight_layout=True)
 
177
  fig.axes[0].set_ylabel('Predicted')
178
  fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
179
  except Exception as e:
180
+ print(f'WARNING: ConfusionMatrix plot failure: {e}')
181
 
182
  def print(self):
183
  for i in range(self.nc + 1):