glenn-jocher commited on
Commit
02719dd
·
unverified ·
1 Parent(s): 20d45aa

Update `feature_visualization()` (#3807)

Browse files

* Update `feature_visualization()`

Only plot for data with height, width > 1

* cleanup

* Cleanup

Files changed (1) hide show
  1. utils/plots.py +21 -19
utils/plots.py CHANGED
@@ -448,26 +448,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
448
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
449
 
450
 
451
- def feature_visualization(features, module_type, module_idx, n=64):
452
  """
453
- features: Features to be visualized
454
  module_type: Module type
455
- module_idx: Module layer index within model
456
  n: Maximum number of feature maps to plot
457
  """
458
- project, name = 'runs/features', 'exp'
459
- save_dir = increment_path(Path(project) / name) # increment run
460
- save_dir.mkdir(parents=True, exist_ok=True) # make dir
461
-
462
- plt.figure(tight_layout=True)
463
- blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
464
- n = min(n, len(blocks))
465
- for i in range(n):
466
- feature = transforms.ToPILImage()(blocks[i].squeeze())
467
- ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
468
- ax.axis('off')
469
- plt.imshow(feature) # cmap='gray'
470
-
471
- f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png"
472
- print(f'Saving {save_dir / f}...')
473
- plt.savefig(save_dir / f, dpi=300)
 
 
 
448
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
449
 
450
 
451
+ def feature_visualization(x, module_type, stage, n=64):
452
  """
453
+ x: Features to be visualized
454
  module_type: Module type
455
+ stage: Module stage within model
456
  n: Maximum number of feature maps to plot
457
  """
458
+ batch, channels, height, width = x.shape # batch, channels, height, width
459
+ if height > 1 and width > 1:
460
+ project, name = 'runs/features', 'exp'
461
+ save_dir = increment_path(Path(project) / name) # increment run
462
+ save_dir.mkdir(parents=True, exist_ok=True) # make dir
463
+
464
+ plt.figure(tight_layout=True)
465
+ blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
466
+ n = min(n, len(blocks))
467
+ for i in range(n):
468
+ feature = transforms.ToPILImage()(blocks[i].squeeze())
469
+ ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
470
+ ax.axis('off')
471
+ plt.imshow(feature) # cmap='gray'
472
+
473
+ f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
474
+ print(f'Saving {save_dir / f}...')
475
+ plt.savefig(save_dir / f, dpi=300)