glenn-jocher commited on
Commit
248504c
1 Parent(s): dabad57

Feature visualization improvements 32 (#3947)

Browse files
Files changed (2) hide show
  1. detect.py +1 -1
  2. utils/plots.py +6 -5
detect.py CHANGED
@@ -103,7 +103,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
103
  t1 = time_synchronized()
104
  pred = model(img,
105
  augment=augment,
106
- visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]
107
 
108
  # Apply NMS
109
  pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
 
103
  t1 = time_synchronized()
104
  pred = model(img,
105
  augment=augment,
106
+ visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]
107
 
108
  # Apply NMS
109
  pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
utils/plots.py CHANGED
@@ -16,7 +16,7 @@ import torch
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
18
 
19
- from utils.general import increment_path, xywh2xyxy, xyxy2xywh
20
  from utils.metrics import fitness
21
 
22
  # Settings
@@ -447,7 +447,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
447
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
448
 
449
 
450
- def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
451
  """
452
  x: Features to be visualized
453
  module_type: Module type
@@ -460,13 +460,14 @@ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detec
460
  if height > 1 and width > 1:
461
  f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
462
 
463
- plt.figure(tight_layout=True)
464
  blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
465
  n = min(n, channels) # number of plots
466
- ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
 
 
467
  for i in range(n):
468
  ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
469
  ax[i].axis('off')
470
 
471
  print(f'Saving {save_dir / f}... ({n}/{channels})')
472
- plt.savefig(save_dir / f, dpi=300)
 
16
  import yaml
17
  from PIL import Image, ImageDraw, ImageFont
18
 
19
+ from utils.general import xywh2xyxy, xyxy2xywh
20
  from utils.metrics import fitness
21
 
22
  # Settings
 
447
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
448
 
449
 
450
+ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
451
  """
452
  x: Features to be visualized
453
  module_type: Module type
 
460
  if height > 1 and width > 1:
461
  f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
462
 
 
463
  blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
464
  n = min(n, channels) # number of plots
465
+ fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
466
+ ax = ax.ravel()
467
+ plt.subplots_adjust(wspace=0.05, hspace=0.05)
468
  for i in range(n):
469
  ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
470
  ax[i].axis('off')
471
 
472
  print(f'Saving {save_dir / f}... ({n}/{channels})')
473
+ plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')