Update `feature_visualization()` (#3807)
Browse files* Update `feature_visualization()`
Only plot for data with height, width > 1
* cleanup
* Cleanup
- utils/plots.py +21 -19
utils/plots.py
CHANGED
@@ -448,26 +448,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
|
|
448 |
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
449 |
|
450 |
|
451 |
-
def feature_visualization(
|
452 |
"""
|
453 |
-
|
454 |
module_type: Module type
|
455 |
-
|
456 |
n: Maximum number of feature maps to plot
|
457 |
"""
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
|
|
|
|
|
448 |
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
449 |
|
450 |
|
451 |
+
def feature_visualization(x, module_type, stage, n=64):
|
452 |
"""
|
453 |
+
x: Features to be visualized
|
454 |
module_type: Module type
|
455 |
+
stage: Module stage within model
|
456 |
n: Maximum number of feature maps to plot
|
457 |
"""
|
458 |
+
batch, channels, height, width = x.shape # batch, channels, height, width
|
459 |
+
if height > 1 and width > 1:
|
460 |
+
project, name = 'runs/features', 'exp'
|
461 |
+
save_dir = increment_path(Path(project) / name) # increment run
|
462 |
+
save_dir.mkdir(parents=True, exist_ok=True) # make dir
|
463 |
+
|
464 |
+
plt.figure(tight_layout=True)
|
465 |
+
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
|
466 |
+
n = min(n, len(blocks))
|
467 |
+
for i in range(n):
|
468 |
+
feature = transforms.ToPILImage()(blocks[i].squeeze())
|
469 |
+
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
|
470 |
+
ax.axis('off')
|
471 |
+
plt.imshow(feature) # cmap='gray'
|
472 |
+
|
473 |
+
f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
|
474 |
+
print(f'Saving {save_dir / f}...')
|
475 |
+
plt.savefig(save_dir / f, dpi=300)
|