Zigarss glenn-jocher commited on
Commit
20d45aa
1 Parent(s): 3974d72

Add feature map visualization (#3804)

Browse files

* Add feature map visualization

Add a feature_visualization function to visualize the mid feature map of the model.

* Update yolo.py

* remove boolean from forward and reorder if statement

* remove print from forward

* General cleanup

* Indent

* Update plots.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (2) hide show
  1. models/yolo.py +5 -1
  2. utils/plots.py +28 -2
models/yolo.py CHANGED
@@ -17,6 +17,7 @@ from models.common import *
17
  from models.experimental import *
18
  from utils.autoanchor import check_anchor_order
19
  from utils.general import make_divisible, check_file, set_logging
 
20
  from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
21
  select_device, copy_attr
22
 
@@ -135,7 +136,7 @@ class Model(nn.Module):
135
  y.append(yi)
136
  return torch.cat(y, 1), None # augmented inference, train
137
 
138
- def forward_once(self, x, profile=False):
139
  y, dt = [], [] # outputs
140
  for m in self.model:
141
  if m.f != -1: # if not from previous layer
@@ -153,6 +154,9 @@ class Model(nn.Module):
153
 
154
  x = m(x) # run
155
  y.append(x if m.i in self.save else None) # save output
 
 
 
156
 
157
  if profile:
158
  logger.info('%.1fms total' % sum(dt))
 
17
  from models.experimental import *
18
  from utils.autoanchor import check_anchor_order
19
  from utils.general import make_divisible, check_file, set_logging
20
+ from utils.plots import feature_visualization
21
  from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
22
  select_device, copy_attr
23
 
 
136
  y.append(yi)
137
  return torch.cat(y, 1), None # augmented inference, train
138
 
139
+ def forward_once(self, x, profile=False, feature_vis=False):
140
  y, dt = [], [] # outputs
141
  for m in self.model:
142
  if m.f != -1: # if not from previous layer
 
154
 
155
  x = m(x) # run
156
  y.append(x if m.i in self.save else None) # save output
157
+
158
+ if feature_vis and m.type == 'models.common.SPP':
159
+ feature_visualization(x, m.type, m.i)
160
 
161
  if profile:
162
  logger.info('%.1fms total' % sum(dt))
utils/plots.py CHANGED
@@ -15,8 +15,9 @@ import seaborn as sn
15
  import torch
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
 
18
 
19
- from utils.general import xywh2xyxy, xyxy2xywh
20
  from utils.metrics import fitness
21
 
22
  # Settings
@@ -299,7 +300,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
299
  matplotlib.use('svg') # faster
300
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
301
  y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
302
- # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
303
  ax[0].set_ylabel('instances')
304
  if 0 < len(names) < 30:
305
  ax[0].set_xticks(range(len(names)))
@@ -445,3 +446,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
445
 
446
  ax[1].legend()
447
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  import torch
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
18
+ from torchvision import transforms
19
 
20
+ from utils.general import increment_path, xywh2xyxy, xyxy2xywh
21
  from utils.metrics import fitness
22
 
23
  # Settings
 
300
  matplotlib.use('svg') # faster
301
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
302
  y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
303
+ # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
304
  ax[0].set_ylabel('instances')
305
  if 0 < len(names) < 30:
306
  ax[0].set_xticks(range(len(names)))
 
446
 
447
  ax[1].legend()
448
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
449
+
450
+
451
+ def feature_visualization(features, module_type, module_idx, n=64):
452
+ """
453
+ features: Features to be visualized
454
+ module_type: Module type
455
+ module_idx: Module layer index within model
456
+ n: Maximum number of feature maps to plot
457
+ """
458
+ project, name = 'runs/features', 'exp'
459
+ save_dir = increment_path(Path(project) / name) # increment run
460
+ save_dir.mkdir(parents=True, exist_ok=True) # make dir
461
+
462
+ plt.figure(tight_layout=True)
463
+ blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
464
+ n = min(n, len(blocks))
465
+ for i in range(n):
466
+ feature = transforms.ToPILImage()(blocks[i].squeeze())
467
+ ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
468
+ ax.axis('off')
469
+ plt.imshow(feature) # cmap='gray'
470
+
471
+ f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png"
472
+ print(f'Saving {save_dir / f}...')
473
+ plt.savefig(save_dir / f, dpi=300)