glenn-jocher
commited on
Commit
•
0bd9c48
1
Parent(s):
394d1c8
Update torch_utils.py
Browse files- utils/torch_utils.py +3 -3
utils/torch_utils.py
CHANGED
@@ -86,7 +86,7 @@ def profile(x, ops, n=100, device=None):
|
|
86 |
x = x.to(device)
|
87 |
x.requires_grad = True
|
88 |
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|
89 |
-
print(f"\n{'Params':>12s}{'
|
90 |
for m in ops if isinstance(ops, list) else [ops]:
|
91 |
m = m.to(device) if hasattr(m, 'to') else m
|
92 |
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
@@ -197,9 +197,9 @@ def model_info(model, verbose=False, img_size=640):
|
|
197 |
from thop import profile
|
198 |
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
|
199 |
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
|
200 |
-
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride
|
201 |
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
202 |
-
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640
|
203 |
except (ImportError, Exception):
|
204 |
fs = ''
|
205 |
|
|
|
86 |
x = x.to(device)
|
87 |
x.requires_grad = True
|
88 |
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|
89 |
+
print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
|
90 |
for m in ops if isinstance(ops, list) else [ops]:
|
91 |
m = m.to(device) if hasattr(m, 'to') else m
|
92 |
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
|
|
197 |
from thop import profile
|
198 |
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
|
199 |
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
|
200 |
+
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
|
201 |
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
202 |
+
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
|
203 |
except (ImportError, Exception):
|
204 |
fs = ''
|
205 |
|