glenn-jocher commited on
Commit
0bd9c48
1 Parent(s): 394d1c8

Update torch_utils.py

Browse files
Files changed (1) hide show
  1. utils/torch_utils.py +3 -3
utils/torch_utils.py CHANGED
@@ -86,7 +86,7 @@ def profile(x, ops, n=100, device=None):
86
  x = x.to(device)
87
  x.requires_grad = True
88
  print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
89
- print(f"\n{'Params':>12s}{'FLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
90
  for m in ops if isinstance(ops, list) else [ops]:
91
  m = m.to(device) if hasattr(m, 'to') else m
92
  dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
@@ -197,9 +197,9 @@ def model_info(model, verbose=False, img_size=640):
197
  from thop import profile
198
  stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
199
  img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
200
- flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
201
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
202
- fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
203
  except (ImportError, Exception):
204
  fs = ''
205
 
 
86
  x = x.to(device)
87
  x.requires_grad = True
88
  print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
89
+ print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
90
  for m in ops if isinstance(ops, list) else [ops]:
91
  m = m.to(device) if hasattr(m, 'to') else m
92
  dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
 
197
  from thop import profile
198
  stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
199
  img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
200
+ flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
201
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
202
+ fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
203
  except (ImportError, Exception):
204
  fs = ''
205