glenn-jocher commited on
Commit
61ea23c
1 Parent(s): 73a92dc

Implement `@torch.no_grad()` decorator (#3312)

Browse files

* `@torch.no_grad()` decorator

* Update detect.py

Files changed (2) hide show
  1. detect.py +6 -6
  2. test.py +16 -16
detect.py CHANGED
@@ -14,6 +14,7 @@ from utils.plots import colors, plot_one_box
14
  from utils.torch_utils import select_device, load_classifier, time_synchronized
15
 
16
 
 
17
  def detect(opt):
18
  source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
19
  save_img = not opt.nosave and not source.endswith('.txt') # save inference images
@@ -175,10 +176,9 @@ if __name__ == '__main__':
175
  print(opt)
176
  check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))
177
 
178
- with torch.no_grad():
179
- if opt.update: # update all models (to fix SourceChangeWarning)
180
- for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
181
- detect(opt=opt)
182
- strip_optimizer(opt.weights)
183
- else:
184
  detect(opt=opt)
 
 
 
 
14
  from utils.torch_utils import select_device, load_classifier, time_synchronized
15
 
16
 
17
+ @torch.no_grad()
18
  def detect(opt):
19
  source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
20
  save_img = not opt.nosave and not source.endswith('.txt') # save inference images
 
176
  print(opt)
177
  check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))
178
 
179
+ if opt.update: # update all models (to fix SourceChangeWarning)
180
+ for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
 
 
 
 
181
  detect(opt=opt)
182
+ strip_optimizer(opt.weights)
183
+ else:
184
+ detect(opt=opt)
test.py CHANGED
@@ -18,6 +18,7 @@ from utils.plots import plot_images, output_to_target, plot_study_txt
18
  from utils.torch_utils import select_device, time_synchronized
19
 
20
 
 
21
  def test(data,
22
  weights=None,
23
  batch_size=32,
@@ -105,22 +106,21 @@ def test(data,
105
  targets = targets.to(device)
106
  nb, _, height, width = img.shape # batch size, channels, height, width
107
 
108
- with torch.no_grad():
109
- # Run model
110
- t = time_synchronized()
111
- out, train_out = model(img, augment=augment) # inference and training outputs
112
- t0 += time_synchronized() - t
113
-
114
- # Compute loss
115
- if compute_loss:
116
- loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls
117
-
118
- # Run NMS
119
- targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
120
- lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
121
- t = time_synchronized()
122
- out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
123
- t1 += time_synchronized() - t
124
 
125
  # Statistics per image
126
  for si, pred in enumerate(out):
 
18
  from utils.torch_utils import select_device, time_synchronized
19
 
20
 
21
+ @torch.no_grad()
22
  def test(data,
23
  weights=None,
24
  batch_size=32,
 
106
  targets = targets.to(device)
107
  nb, _, height, width = img.shape # batch size, channels, height, width
108
 
109
+ # Run model
110
+ t = time_synchronized()
111
+ out, train_out = model(img, augment=augment) # inference and training outputs
112
+ t0 += time_synchronized() - t
113
+
114
+ # Compute loss
115
+ if compute_loss:
116
+ loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls
117
+
118
+ # Run NMS
119
+ targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
120
+ lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
121
+ t = time_synchronized()
122
+ out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
123
+ t1 += time_synchronized() - t
 
124
 
125
  # Statistics per image
126
  for si, pred in enumerate(out):