glenn-jocher
commited on
Commit
•
095197b
1
Parent(s):
4695ca8
Ignore Seaborn plot warnings (#3576)
Browse files* Ignore Seaborn plot warnings
* Update plots.py
* Update metrics.py
- utils/metrics.py +6 -3
- utils/plots.py +4 -4
utils/metrics.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
# Model validation metrics
|
2 |
|
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
import matplotlib.pyplot as plt
|
@@ -167,9 +168,11 @@ class ConfusionMatrix:
|
|
167 |
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
168 |
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
|
169 |
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
173 |
fig.axes[0].set_xlabel('True')
|
174 |
fig.axes[0].set_ylabel('Predicted')
|
175 |
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
|
|
1 |
# Model validation metrics
|
2 |
|
3 |
+
import warnings
|
4 |
from pathlib import Path
|
5 |
|
6 |
import matplotlib.pyplot as plt
|
|
|
168 |
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
169 |
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
|
170 |
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
|
171 |
+
with warnings.catch_warnings():
|
172 |
+
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
173 |
+
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
|
174 |
+
xticklabels=names + ['background FP'] if labels else "auto",
|
175 |
+
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
176 |
fig.axes[0].set_xlabel('True')
|
177 |
fig.axes[0].set_ylabel('Predicted')
|
178 |
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
utils/plots.py
CHANGED
@@ -11,7 +11,7 @@ import matplotlib
|
|
11 |
import matplotlib.pyplot as plt
|
12 |
import numpy as np
|
13 |
import pandas as pd
|
14 |
-
import seaborn as
|
15 |
import torch
|
16 |
import yaml
|
17 |
from PIL import Image, ImageDraw, ImageFont
|
@@ -291,7 +291,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
|
|
291 |
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
292 |
|
293 |
# seaborn correlogram
|
294 |
-
|
295 |
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
296 |
plt.close()
|
297 |
|
@@ -306,8 +306,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
|
|
306 |
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
|
307 |
else:
|
308 |
ax[0].set_xlabel('classes')
|
309 |
-
|
310 |
-
|
311 |
|
312 |
# rectangles
|
313 |
labels[:, 1:3] = 0.5 # center
|
|
|
11 |
import matplotlib.pyplot as plt
|
12 |
import numpy as np
|
13 |
import pandas as pd
|
14 |
+
import seaborn as sn
|
15 |
import torch
|
16 |
import yaml
|
17 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
291 |
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
292 |
|
293 |
# seaborn correlogram
|
294 |
+
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
295 |
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
296 |
plt.close()
|
297 |
|
|
|
306 |
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
|
307 |
else:
|
308 |
ax[0].set_xlabel('classes')
|
309 |
+
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
310 |
+
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
311 |
|
312 |
# rectangles
|
313 |
labels[:, 1:3] = 0.5 # center
|