glenn-jocher committed on
Commit
d8f5fcf
·
unverified ·
1 Parent(s): 0c26c4e

Improved FLOPS computation (#1398)

Browse files

* Improved FLOPS computation

* update comment

Files changed (2) hide show
  1. models/yolo.py +2 -2
  2. utils/torch_utils.py +6 -4
models/yolo.py CHANGED
@@ -192,8 +192,8 @@ class Model(nn.Module):
192
  copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
193
  return m
194
 
195
- def info(self, verbose=False): # print model information
196
- model_info(self, verbose)
197
 
198
 
199
  def parse_model(d, ch): # model_dict, input_channels(3)
 
192
  copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
193
  return m
194
 
195
+ def info(self, verbose=False, img_size=640): # print model information
196
+ model_info(self, verbose, img_size)
197
 
198
 
199
  def parse_model(d, ch): # model_dict, input_channels(3)
utils/torch_utils.py CHANGED
@@ -139,8 +139,8 @@ def fuse_conv_and_bn(conv, bn):
139
  return fusedconv
140
 
141
 
142
- def model_info(model, verbose=False):
143
- # Plots a line-by-line description of a PyTorch model
144
  n_p = sum(x.numel() for x in model.parameters()) # number parameters
145
  n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
146
  if verbose:
@@ -152,8 +152,10 @@ def model_info(model, verbose=False):
152
 
153
  try: # FLOPS
154
  from thop import profile
155
- flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
156
- fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
 
 
157
  except ImportError:
158
  fs = ''
159
 
 
139
  return fusedconv
140
 
141
 
142
+ def model_info(model, verbose=False, img_size=640):
143
+ # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
144
  n_p = sum(x.numel() for x in model.parameters()) # number parameters
145
  n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
146
  if verbose:
 
152
 
153
  try: # FLOPS
154
  from thop import profile
155
+ stride = int(model.stride.max())
156
+ flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
157
+ img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
158
+ fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
159
  except ImportError:
160
  fs = ''
161