glenn-jocher commited on
Commit
095197b
1 Parent(s): 4695ca8

Ignore Seaborn plot warnings (#3576)

Browse files

* Ignore Seaborn plot warnings

* Update plots.py

* Update metrics.py

Files changed (2) hide show
  1. utils/metrics.py +6 -3
  2. utils/plots.py +4 -4
utils/metrics.py CHANGED
@@ -1,5 +1,6 @@
1
  # Model validation metrics
2
 
 
3
  from pathlib import Path
4
 
5
  import matplotlib.pyplot as plt
@@ -167,9 +168,11 @@ class ConfusionMatrix:
167
  fig = plt.figure(figsize=(12, 9), tight_layout=True)
168
  sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
169
  labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
170
- sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
171
- xticklabels=names + ['background FP'] if labels else "auto",
172
- yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
 
 
173
  fig.axes[0].set_xlabel('True')
174
  fig.axes[0].set_ylabel('Predicted')
175
  fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
 
1
  # Model validation metrics
2
 
3
+ import warnings
4
  from pathlib import Path
5
 
6
  import matplotlib.pyplot as plt
 
168
  fig = plt.figure(figsize=(12, 9), tight_layout=True)
169
  sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
170
  labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
171
+ with warnings.catch_warnings():
172
+ warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
173
+ sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
174
+ xticklabels=names + ['background FP'] if labels else "auto",
175
+ yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
176
  fig.axes[0].set_xlabel('True')
177
  fig.axes[0].set_ylabel('Predicted')
178
  fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
utils/plots.py CHANGED
@@ -11,7 +11,7 @@ import matplotlib
11
  import matplotlib.pyplot as plt
12
  import numpy as np
13
  import pandas as pd
14
- import seaborn as sns
15
  import torch
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
@@ -291,7 +291,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
291
  x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
292
 
293
  # seaborn correlogram
294
- sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
295
  plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
296
  plt.close()
297
 
@@ -306,8 +306,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
306
  ax[0].set_xticklabels(names, rotation=90, fontsize=10)
307
  else:
308
  ax[0].set_xlabel('classes')
309
- sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
310
- sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
311
 
312
  # rectangles
313
  labels[:, 1:3] = 0.5 # center
 
11
  import matplotlib.pyplot as plt
12
  import numpy as np
13
  import pandas as pd
14
+ import seaborn as sn
15
  import torch
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
 
291
  x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
292
 
293
  # seaborn correlogram
294
+ sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
295
  plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
296
  plt.close()
297
 
 
306
  ax[0].set_xticklabels(names, rotation=90, fontsize=10)
307
  else:
308
  ax[0].set_xlabel('classes')
309
+ sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
310
+ sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
311
 
312
  # rectangles
313
  labels[:, 1:3] = 0.5 # center