glenn-jocher
commited on
Commit
•
248504c
1
Parent(s):
dabad57
Feature visualization improvements 32 (#3947)
Browse files- detect.py +1 -1
- utils/plots.py +6 -5
detect.py
CHANGED
@@ -103,7 +103,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
|
|
103 |
t1 = time_synchronized()
|
104 |
pred = model(img,
|
105 |
augment=augment,
|
106 |
-
visualize=increment_path(save_dir /
|
107 |
|
108 |
# Apply NMS
|
109 |
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
|
|
|
103 |
t1 = time_synchronized()
|
104 |
pred = model(img,
|
105 |
augment=augment,
|
106 |
+
visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]
|
107 |
|
108 |
# Apply NMS
|
109 |
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
|
utils/plots.py
CHANGED
@@ -16,7 +16,7 @@ import torch
|
|
16 |
import yaml
|
17 |
from PIL import Image, ImageDraw, ImageFont
|
18 |
|
19 |
-
from utils.general import
|
20 |
from utils.metrics import fitness
|
21 |
|
22 |
# Settings
|
@@ -447,7 +447,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
|
|
447 |
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
448 |
|
449 |
|
450 |
-
def feature_visualization(x, module_type, stage, n=
|
451 |
"""
|
452 |
x: Features to be visualized
|
453 |
module_type: Module type
|
@@ -460,13 +460,14 @@ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detec
|
|
460 |
if height > 1 and width > 1:
|
461 |
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
462 |
|
463 |
-
plt.figure(tight_layout=True)
|
464 |
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
465 |
n = min(n, channels) # number of plots
|
466 |
-
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)
|
|
|
|
|
467 |
for i in range(n):
|
468 |
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
469 |
ax[i].axis('off')
|
470 |
|
471 |
print(f'Saving {save_dir / f}... ({n}/{channels})')
|
472 |
-
plt.savefig(save_dir / f, dpi=300)
|
|
|
16 |
import yaml
|
17 |
from PIL import Image, ImageDraw, ImageFont
|
18 |
|
19 |
+
from utils.general import xywh2xyxy, xyxy2xywh
|
20 |
from utils.metrics import fitness
|
21 |
|
22 |
# Settings
|
|
|
447 |
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
448 |
|
449 |
|
450 |
+
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
|
451 |
"""
|
452 |
x: Features to be visualized
|
453 |
module_type: Module type
|
|
|
460 |
if height > 1 and width > 1:
|
461 |
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
462 |
|
|
|
463 |
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
464 |
n = min(n, channels) # number of plots
|
465 |
+
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
|
466 |
+
ax = ax.ravel()
|
467 |
+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
468 |
for i in range(n):
|
469 |
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
470 |
ax[i].axis('off')
|
471 |
|
472 |
print(f'Saving {save_dir / f}... ({n}/{channels})')
|
473 |
+
plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')
|