glenn-jocher commited on
Commit
87b094b
1 Parent(s): 61047a2

Feature visualization update (#3920)

Browse files

* Feature visualization update

* Save to jpg (faster)

* Save to png

Files changed (3) hide show
  1. detect.py +5 -1
  2. models/yolo.py +5 -6
  3. utils/plots.py +18 -21
detect.py CHANGED
@@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
40
  classes=None, # filter by class: --class 0, or --class 0 2 3
41
  agnostic_nms=False, # class-agnostic NMS
42
  augment=False, # augmented inference
 
43
  update=False, # update all models
44
  project='runs/detect', # save results to project/name
45
  name='exp', # save results to project/name
@@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)
100
 
101
  # Inference
102
  t1 = time_synchronized()
103
- pred = model(img, augment=augment)[0]
 
 
104
 
105
  # Apply NMS
106
  pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
@@ -201,6 +204,7 @@ def parse_opt():
201
  parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
202
  parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
203
  parser.add_argument('--augment', action='store_true', help='augmented inference')
 
204
  parser.add_argument('--update', action='store_true', help='update all models')
205
  parser.add_argument('--project', default='runs/detect', help='save results to project/name')
206
  parser.add_argument('--name', default='exp', help='save results to project/name')
 
40
  classes=None, # filter by class: --class 0, or --class 0 2 3
41
  agnostic_nms=False, # class-agnostic NMS
42
  augment=False, # augmented inference
43
+ visualize=False, # visualize features
44
  update=False, # update all models
45
  project='runs/detect', # save results to project/name
46
  name='exp', # save results to project/name
 
101
 
102
  # Inference
103
  t1 = time_synchronized()
104
+ pred = model(img,
105
+ augment=augment,
106
+ visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]
107
 
108
  # Apply NMS
109
  pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
 
204
  parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
205
  parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
206
  parser.add_argument('--augment', action='store_true', help='augmented inference')
207
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
208
  parser.add_argument('--update', action='store_true', help='update all models')
209
  parser.add_argument('--project', default='runs/detect', help='save results to project/name')
210
  parser.add_argument('--name', default='exp', help='save results to project/name')
models/yolo.py CHANGED
@@ -117,11 +117,10 @@ class Model(nn.Module):
117
  self.info()
118
  logger.info('')
119
 
120
- def forward(self, x, augment=False, profile=False):
121
  if augment:
122
  return self.forward_augment(x) # augmented inference, None
123
- else:
124
- return self.forward_once(x, profile) # single-scale inference, train
125
 
126
  def forward_augment(self, x):
127
  img_size = x.shape[-2:] # height, width
@@ -136,7 +135,7 @@ class Model(nn.Module):
136
  y.append(yi)
137
  return torch.cat(y, 1), None # augmented inference, train
138
 
139
- def forward_once(self, x, profile=False, feature_vis=False):
140
  y, dt = [], [] # outputs
141
  for m in self.model:
142
  if m.f != -1: # if not from previous layer
@@ -155,8 +154,8 @@ class Model(nn.Module):
155
  x = m(x) # run
156
  y.append(x if m.i in self.save else None) # save output
157
 
158
- if feature_vis and m.type == 'models.common.SPP':
159
- feature_visualization(x, m.type, m.i)
160
 
161
  if profile:
162
  logger.info('%.1fms total' % sum(dt))
 
117
  self.info()
118
  logger.info('')
119
 
120
+ def forward(self, x, augment=False, profile=False, visualize=False):
121
  if augment:
122
  return self.forward_augment(x) # augmented inference, None
123
+ return self.forward_once(x, profile, visualize) # single-scale inference, train
 
124
 
125
  def forward_augment(self, x):
126
  img_size = x.shape[-2:] # height, width
 
135
  y.append(yi)
136
  return torch.cat(y, 1), None # augmented inference, train
137
 
138
+ def forward_once(self, x, profile=False, visualize=False):
139
  y, dt = [], [] # outputs
140
  for m in self.model:
141
  if m.f != -1: # if not from previous layer
 
154
  x = m(x) # run
155
  y.append(x if m.i in self.save else None) # save output
156
 
157
+ if visualize:
158
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
159
 
160
  if profile:
161
  logger.info('%.1fms total' % sum(dt))
utils/plots.py CHANGED
@@ -1,12 +1,12 @@
1
  # Plotting utils
2
 
3
  import glob
4
- import math
5
  import os
6
  from copy import copy
7
  from pathlib import Path
8
 
9
  import cv2
 
10
  import matplotlib
11
  import matplotlib.pyplot as plt
12
  import numpy as np
@@ -15,7 +15,6 @@ import seaborn as sn
15
  import torch
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
18
- from torchvision import transforms
19
 
20
  from utils.general import increment_path, xywh2xyxy, xyxy2xywh
21
  from utils.metrics import fitness
@@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
448
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
449
 
450
 
451
- def feature_visualization(x, module_type, stage, n=64):
452
  """
453
  x: Features to be visualized
454
  module_type: Module type
455
  stage: Module stage within model
456
  n: Maximum number of feature maps to plot
 
457
  """
458
- batch, channels, height, width = x.shape # batch, channels, height, width
459
- if height > 1 and width > 1:
460
- project, name = 'runs/features', 'exp'
461
- save_dir = increment_path(Path(project) / name) # increment run
462
- save_dir.mkdir(parents=True, exist_ok=True) # make dir
463
-
464
- plt.figure(tight_layout=True)
465
- blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
466
- n = min(n, len(blocks))
467
- for i in range(n):
468
- feature = transforms.ToPILImage()(blocks[i].squeeze())
469
- ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
470
- ax.axis('off')
471
- plt.imshow(feature) # cmap='gray'
472
-
473
- f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
474
- print(f'Saving {save_dir / f}...')
475
- plt.savefig(save_dir / f, dpi=300)
 
1
  # Plotting utils
2
 
3
  import glob
 
4
  import os
5
  from copy import copy
6
  from pathlib import Path
7
 
8
  import cv2
9
+ import math
10
  import matplotlib
11
  import matplotlib.pyplot as plt
12
  import numpy as np
 
15
  import torch
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
 
18
 
19
  from utils.general import increment_path, xywh2xyxy, xyxy2xywh
20
  from utils.metrics import fitness
 
447
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
448
 
449
 
450
+ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
451
  """
452
  x: Features to be visualized
453
  module_type: Module type
454
  stage: Module stage within model
455
  n: Maximum number of feature maps to plot
456
+ save_dir: Directory to save results
457
  """
458
+ if 'Detect' not in module_type:
459
+ batch, channels, height, width = x.shape # batch, channels, height, width
460
+ if height > 1 and width > 1:
461
+ f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
462
+
463
+ plt.figure(tight_layout=True)
464
+ blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels
465
+ n = min(n, channels) # number of plots
466
+ ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
467
+ for i in range(n):
468
+ ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
469
+ ax[i].axis('off')
470
+
471
+ print(f'Saving {save_dir / f}... ({n}/{channels})')
472
+ plt.savefig(save_dir / f, dpi=300)