Improved FLOPS computation (#1398)
Browse files* Improved FLOPS computation
* update comment
- models/yolo.py +2 -2
- utils/torch_utils.py +6 -4
models/yolo.py
CHANGED
@@ -192,8 +192,8 @@ class Model(nn.Module):
|
|
192 |
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
|
193 |
return m
|
194 |
|
195 |
-
def info(self, verbose=False): # print model information
|
196 |
-
model_info(self, verbose)
|
197 |
|
198 |
|
199 |
def parse_model(d, ch): # model_dict, input_channels(3)
|
|
|
192 |
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
|
193 |
return m
|
194 |
|
195 |
+
def info(self, verbose=False, img_size=640): # print model information
|
196 |
+
model_info(self, verbose, img_size)
|
197 |
|
198 |
|
199 |
def parse_model(d, ch): # model_dict, input_channels(3)
|
utils/torch_utils.py
CHANGED
@@ -139,8 +139,8 @@ def fuse_conv_and_bn(conv, bn):
|
|
139 |
return fusedconv
|
140 |
|
141 |
|
142 |
-
def model_info(model, verbose=False):
|
143 |
-
#
|
144 |
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
145 |
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
|
146 |
if verbose:
|
@@ -152,8 +152,10 @@ def model_info(model, verbose=False):
|
|
152 |
|
153 |
try: # FLOPS
|
154 |
from thop import profile
|
155 |
-
|
156 |
-
|
|
|
|
|
157 |
except ImportError:
|
158 |
fs = ''
|
159 |
|
|
|
139 |
return fusedconv
|
140 |
|
141 |
|
142 |
+
def model_info(model, verbose=False, img_size=640):
|
143 |
+
# Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
|
144 |
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
145 |
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
|
146 |
if verbose:
|
|
|
152 |
|
153 |
try: # FLOPS
|
154 |
from thop import profile
|
155 |
+
stride = int(model.stride.max())
|
156 |
+
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
|
157 |
+
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
158 |
+
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
|
159 |
except ImportError:
|
160 |
fs = ''
|
161 |
|