glenn-jocher committed on
Commit
b5de52c
1 Parent(s): fca5e2a

torch.cuda.amp bug fix (#2750)

Browse files

PR https://github.com/ultralytics/yolov5/pull/2725 introduced a very specific bug that only affects multi-GPU trainings. Apparently the cause was using the torch.cuda.amp decorator in the autoShape forward method. I've implemented amp more traditionally in this PR, and the bug is resolved.

Files changed (1) hide show
  1. models/common.py +13 -11
models/common.py CHANGED
@@ -10,6 +10,7 @@ import requests
10
  import torch
11
  import torch.nn as nn
12
  from PIL import Image
 
13
 
14
  from utils.datasets import letterbox
15
  from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
@@ -237,7 +238,6 @@ class autoShape(nn.Module):
237
  return self
238
 
239
  @torch.no_grad()
240
- @torch.cuda.amp.autocast(torch.cuda.is_available())
241
  def forward(self, imgs, size=640, augment=False, profile=False):
242
  # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
243
  # filename: imgs = 'data/samples/zidane.jpg'
@@ -251,7 +251,8 @@ class autoShape(nn.Module):
251
  t = [time_synchronized()]
252
  p = next(self.model.parameters()) # for device and type
253
  if isinstance(imgs, torch.Tensor): # torch
254
- return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
 
255
 
256
  # Pre-process
257
  n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
@@ -278,17 +279,18 @@ class autoShape(nn.Module):
278
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
279
  t.append(time_synchronized())
280
 
281
- # Inference
282
- y = self.model(x, augment, profile)[0] # forward
283
- t.append(time_synchronized())
 
284
 
285
- # Post-process
286
- y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
287
- for i in range(n):
288
- scale_coords(shape1, y[i][:, :4], shape0[i])
289
 
290
- t.append(time_synchronized())
291
- return Detections(imgs, y, files, t, self.names, x.shape)
292
 
293
 
294
  class Detections:
 
10
  import torch
11
  import torch.nn as nn
12
  from PIL import Image
13
+ from torch.cuda import amp
14
 
15
  from utils.datasets import letterbox
16
  from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
 
238
  return self
239
 
240
  @torch.no_grad()
 
241
  def forward(self, imgs, size=640, augment=False, profile=False):
242
  # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
243
  # filename: imgs = 'data/samples/zidane.jpg'
 
251
  t = [time_synchronized()]
252
  p = next(self.model.parameters()) # for device and type
253
  if isinstance(imgs, torch.Tensor): # torch
254
+ with amp.autocast(enabled=p.device.type != 'cpu'):
255
+ return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
256
 
257
  # Pre-process
258
  n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
 
279
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
280
  t.append(time_synchronized())
281
 
282
+ with amp.autocast(enabled=p.device.type != 'cpu'):
283
+ # Inference
284
+ y = self.model(x, augment, profile)[0] # forward
285
+ t.append(time_synchronized())
286
 
287
+ # Post-process
288
+ y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
289
+ for i in range(n):
290
+ scale_coords(shape1, y[i][:, :4], shape0[i])
291
 
292
+ t.append(time_synchronized())
293
+ return Detections(imgs, y, files, t, self.names, x.shape)
294
 
295
 
296
  class Detections: