glenn-jocher commited on
Commit
0dc725e
1 Parent(s): 621b6d5

Refactor `forward()` method profiling (#4816)

Browse files
Files changed (1) hide show
  1. models/yolo.py +19 -22
models/yolo.py CHANGED
@@ -98,7 +98,6 @@ class Model(nn.Module):
98
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
99
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
100
  self.inplace = self.yaml.get('inplace', True)
101
- # LOGGER.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
102
 
103
  # Build strides, anchors
104
  m = self.model[-1] # Detect()
@@ -110,7 +109,6 @@ class Model(nn.Module):
110
  check_anchor_order(m)
111
  self.stride = m.stride
112
  self._initialize_biases() # only run once
113
- # LOGGER.info('Strides: %s' % m.stride.tolist())
114
 
115
  # Init weights, biases
116
  initialize_weights(self)
@@ -119,47 +117,33 @@ class Model(nn.Module):
119
 
120
  def forward(self, x, augment=False, profile=False, visualize=False):
121
  if augment:
122
- return self.forward_augment(x) # augmented inference, None
123
- return self.forward_once(x, profile, visualize) # single-scale inference, train
124
 
125
- def forward_augment(self, x):
126
  img_size = x.shape[-2:] # height, width
127
  s = [1, 0.83, 0.67] # scales
128
  f = [None, 3, None] # flips (2-ud, 3-lr)
129
  y = [] # outputs
130
  for si, fi in zip(s, f):
131
  xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
132
- yi = self.forward_once(xi)[0] # forward
133
  # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
134
  yi = self._descale_pred(yi, fi, si, img_size)
135
  y.append(yi)
136
  return torch.cat(y, 1), None # augmented inference, train
137
 
138
- def forward_once(self, x, profile=False, visualize=False):
139
  y, dt = [], [] # outputs
140
  for m in self.model:
141
  if m.f != -1: # if not from previous layer
142
  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
143
-
144
  if profile:
145
- c = isinstance(m, Detect) # copy input as inplace fix
146
- o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
147
- t = time_sync()
148
- for _ in range(10):
149
- m(x.copy() if c else x)
150
- dt.append((time_sync() - t) * 100)
151
- if m == self.model[0]:
152
- LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
153
- LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
154
-
155
  x = m(x) # run
156
  y.append(x if m.i in self.save else None) # save output
157
-
158
  if visualize:
159
  feature_visualization(x, m.type, m.i, save_dir=visualize)
160
-
161
- if profile:
162
- LOGGER.info('%.1fms total' % sum(dt))
163
  return x
164
 
165
  def _descale_pred(self, p, flips, scale, img_size):
@@ -179,6 +163,19 @@ class Model(nn.Module):
179
  p = torch.cat((x, y, wh, p[..., 4:]), -1)
180
  return p
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
183
  # https://arxiv.org/abs/1708.02002 section 3.3
184
  # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
 
98
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
99
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
100
  self.inplace = self.yaml.get('inplace', True)
 
101
 
102
  # Build strides, anchors
103
  m = self.model[-1] # Detect()
 
109
  check_anchor_order(m)
110
  self.stride = m.stride
111
  self._initialize_biases() # only run once
 
112
 
113
  # Init weights, biases
114
  initialize_weights(self)
 
117
 
118
  def forward(self, x, augment=False, profile=False, visualize=False):
119
  if augment:
120
+ return self._forward_augment(x) # augmented inference, None
121
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
122
 
123
+ def _forward_augment(self, x):
124
  img_size = x.shape[-2:] # height, width
125
  s = [1, 0.83, 0.67] # scales
126
  f = [None, 3, None] # flips (2-ud, 3-lr)
127
  y = [] # outputs
128
  for si, fi in zip(s, f):
129
  xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
130
+ yi = self._forward_once(xi)[0] # forward
131
  # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
132
  yi = self._descale_pred(yi, fi, si, img_size)
133
  y.append(yi)
134
  return torch.cat(y, 1), None # augmented inference, train
135
 
136
+ def _forward_once(self, x, profile=False, visualize=False):
137
  y, dt = [], [] # outputs
138
  for m in self.model:
139
  if m.f != -1: # if not from previous layer
140
  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
 
141
  if profile:
142
+ self._profile_one_layer(m, x, dt)
 
 
 
 
 
 
 
 
 
143
  x = m(x) # run
144
  y.append(x if m.i in self.save else None) # save output
 
145
  if visualize:
146
  feature_visualization(x, m.type, m.i, save_dir=visualize)
 
 
 
147
  return x
148
 
149
  def _descale_pred(self, p, flips, scale, img_size):
 
163
  p = torch.cat((x, y, wh, p[..., 4:]), -1)
164
  return p
165
 
166
+ def _profile_one_layer(self, m, x, dt):
167
+ c = isinstance(m, Detect) # is final layer, copy input as inplace fix
168
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
169
+ t = time_sync()
170
+ for _ in range(10):
171
+ m(x.copy() if c else x)
172
+ dt.append((time_sync() - t) * 100)
173
+ if m == self.model[0]:
174
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
175
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
176
+ if c:
177
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
178
+
179
  def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
180
  # https://arxiv.org/abs/1708.02002 section 3.3
181
  # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.