jluntamazon and glenn-jocher committed
Commit 41f5cc5 · unverified · 1 Parent(s): 955eea8

YOLOv5 AWS Inferentia Inplace compatibility updates (#2953)

* Added flag to enable/disable all inplace and assignment operations

* Removed shape print statements

* Scope Detect/Model import to avoid circular dependency

* PEP8

* create _descale_pred()

* replace lost space

* replace list with tuple

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
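
The change is consumed through attempt_load() and the model yaml (see the diffs below). A minimal usage sketch, not part of the commit; the local yolov5s.pt checkpoint path and input size are illustrative assumptions:

    import torch
    from models.experimental import attempt_load

    # Load with the new flag: inplace=False propagates to Detect and Model,
    # so inference avoids in-place slice assignment (needed for Inferentia).
    model = attempt_load('yolov5s.pt', map_location=torch.device('cpu'), inplace=False)
    model.eval()

    img = torch.zeros(1, 3, 640, 640)  # dummy BCHW input
    pred = model(img)[0]               # decoded predictions, shape (1, num_anchors, 85)
    print(pred.shape)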

Files changed (2)
  1. models/experimental.py +5 -3
  2. models/yolo.py +42 -18
models/experimental.py CHANGED
@@ -110,7 +110,9 @@ class Ensemble(nn.ModuleList):
         return y, None  # inference, train output
 
 
-def attempt_load(weights, map_location=None):
+def attempt_load(weights, map_location=None, inplace=True):
+    from models.yolo import Detect, Model
+
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
@@ -120,8 +122,8 @@ def attempt_load(weights, map_location=None):
 
     # Compatibility updates
     for m in model.modules():
-        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
-            m.inplace = True  # pytorch 1.7.0 compatibility
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
+            m.inplace = inplace  # pytorch 1.7.0 compatibility
         elif type(m) is Conv:
             m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
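
One way to sanity-check the flag: trace the model and look for in-place aten ops, whose names end in '_' (e.g. aten::copy_ produced by slice assignment). A sketch under the same assumptions as above (local yolov5s.pt, illustrative sizes); with inplace=False the set should come back empty:

    import torch
    from models.experimental import attempt_load

    model = attempt_load('yolov5s.pt', map_location=torch.device('cpu'), inplace=False).eval()
    img = torch.zeros(1, 3, 640, 640)

    traced = torch.jit.trace(model, img, strict=False)  # strict=False: outputs include a list
    inplace_ops = {n.kind() for n in traced.inlined_graph.nodes() if n.kind().endswith('_')}
    print(inplace_ops or 'no in-place ops found')
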
models/yolo.py CHANGED
@@ -26,7 +26,7 @@ class Detect(nn.Module):
     stride = None  # strides computed during build
     export = False  # onnx export
 
-    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
         super(Detect, self).__init__()
         self.nc = nc  # number of classes
         self.no = nc + 5  # number of outputs per anchor
@@ -37,6 +37,7 @@ class Detect(nn.Module):
         self.register_buffer('anchors', a)  # shape(nl,na,2)
         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+        self.inplace = inplace  # use in-place ops (e.g. slice assignment)
 
     def forward(self, x):
         # x = x.copy()  # for profiling
@@ -52,8 +53,13 @@ class Detect(nn.Module):
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
 
                 y = x[i].sigmoid()
-                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                if self.inplace:
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                 z.append(y.view(bs, -1, self.no))
 
         return x if self.training else (torch.cat(z, 1), x)
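
The two decode branches above are numerically identical; only the second avoids mutating y. A standalone check with illustrative shapes and values (not taken from a real model):

    import torch

    y = torch.rand(1, 3, 20, 20, 85)                    # sigmoid outputs for one scale
    grid = torch.zeros(1, 1, 20, 20, 2)                 # shape(1,1,ny,nx,2)
    anchor_grid = torch.tensor([10., 13.]).view(1, 1, 1, 1, 2)
    stride = 8.

    yi = y.clone()
    yi[..., 0:2] = (yi[..., 0:2] * 2. - 0.5 + grid) * stride  # xy, in place
    yi[..., 2:4] = (yi[..., 2:4] * 2) ** 2 * anchor_grid      # wh, in place

    xy = (y[..., 0:2] * 2. - 0.5 + grid) * stride             # xy, out of place
    wh = (y[..., 2:4] * 2) ** 2 * anchor_grid                 # wh, out of place
    yo = torch.cat((xy, wh, y[..., 4:]), -1)

    assert torch.allclose(yi, yo)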
 
@@ -85,12 +91,14 @@ class Model(nn.Module):
             self.yaml['anchors'] = round(anchors)  # override yaml value
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
         self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
+        self.inplace = self.yaml.get('inplace', True)
         # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
 
         # Build strides, anchors
         m = self.model[-1]  # Detect()
         if isinstance(m, Detect):
             s = 256  # 2x min stride
+            m.inplace = self.inplace
             m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
             m.anchors /= m.stride.view(-1, 1, 1)
             check_anchor_order(m)
@@ -105,24 +113,23 @@ class Model(nn.Module):
 
     def forward(self, x, augment=False, profile=False):
         if augment:
-            img_size = x.shape[-2:]  # height, width
-            s = [1, 0.83, 0.67]  # scales
-            f = [None, 3, None]  # flips (2-ud, 3-lr)
-            y = []  # outputs
-            for si, fi in zip(s, f):
-                xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
-                yi = self.forward_once(xi)[0]  # forward
-                # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
-                yi[..., :4] /= si  # de-scale
-                if fi == 2:
-                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
-                elif fi == 3:
-                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
-                y.append(yi)
-            return torch.cat(y, 1), None  # augmented inference, train
+            return self.forward_augment(x)  # augmented inference, None
         else:
             return self.forward_once(x, profile)  # single-scale inference, train
 
+    def forward_augment(self, x):
+        img_size = x.shape[-2:]  # height, width
+        s = [1, 0.83, 0.67]  # scales
+        f = [None, 3, None]  # flips (2-ud, 3-lr)
+        y = []  # outputs
+        for si, fi in zip(s, f):
+            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
+            yi = self.forward_once(xi)[0]  # forward
+            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
+            yi = self._descale_pred(yi, fi, si, img_size)
+            y.append(yi)
+        return torch.cat(y, 1), None  # augmented inference, train
+
     def forward_once(self, x, profile=False):
         y, dt = [], []  # outputs
         for m in self.model:
@@ -146,6 +153,23 @@ class Model(nn.Module):
         logger.info('%.1fms total' % sum(dt))
         return x
 
+    def _descale_pred(self, p, flips, scale, img_size):
+        # de-scale predictions following augmented inference (inverse operation)
+        if self.inplace:
+            p[..., :4] /= scale  # de-scale
+            if flips == 2:
+                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
+            elif flips == 3:
+                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
+        else:
+            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
+            if flips == 2:
+                y = img_size[0] - y  # de-flip ud
+            elif flips == 3:
+                x = img_size[1] - x  # de-flip lr
+            p = torch.cat((x, y, wh, p[..., 4:]), -1)
+        return p
+
     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
         # https://arxiv.org/abs/1708.02002 section 3.3
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
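
Likewise, the two _descale_pred() branches agree; a standalone sketch on dummy predictions (shapes, scale, and flip values are illustrative assumptions):

    import torch

    p = torch.rand(1, 100, 85)                     # (batch, detections, xywh + conf + classes)
    img_size, scale, flips = (640, 640), 0.83, 3   # lr-flipped, 0.83x-scaled pass

    q = p.clone()                                  # in-place branch
    q[..., :4] /= scale                            # de-scale
    q[..., 0] = img_size[1] - q[..., 0]            # de-flip lr (flips == 3)

    x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # out of place
    x = img_size[1] - x                            # de-flip lr
    r = torch.cat((x, y, wh, p[..., 4:]), -1)

    assert torch.allclose(q, r)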