glenn-jocher committed

Commit ada90e3 • 1 Parent(s): 94a7f55

Profile() feature addition (#1673)

* Profile() feature addition
* cleanup
- utils/torch_utils.py +40 -1
utils/torch_utils.py  CHANGED

@@ -1,18 +1,22 @@
 # PyTorch utils

 import logging
+import math
 import os
 import time
 from contextlib import contextmanager
 from copy import deepcopy

-import math
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision

+try:
+    import thop  # for FLOPS computation
+except ImportError:
+    thop = None
 logger = logging.getLogger(__name__)


@@ -66,10 +70,45 @@ def select_device(device='', batch_size=None):


 def time_synchronized():
+    # pytorch-accurate time
     torch.cuda.synchronize() if torch.cuda.is_available() else None
     return time.time()


+def profile(x, ops, n=100, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
+    # profile a pytorch module or list of modules. Example usage:
+    #     x = torch.randn(16, 3, 640, 640)  # input
+    #     m1 = lambda x: x * torch.sigmoid(x)
+    #     m2 = nn.SiLU()
+    #     profile(x, [m1, m2], n=100)  # profile speed over 100 iterations
+
+    x = x.to(device)
+    x.requires_grad = True
+    print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
+    print(f"\n{'Params':>12s}{'FLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
+    for m in ops if isinstance(ops, list) else [ops]:
+        m = m.to(device) if hasattr(m, 'to') else m
+        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
+        try:
+            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
+        except:
+            flops = 0
+
+        for _ in range(n):
+            t[0] = time_synchronized()
+            y = m(x)
+            t[1] = time_synchronized()
+            _ = y.sum().backward()
+            t[2] = time_synchronized()
+            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
+            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+
+        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
+        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
+        p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
+        print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
+
+
 def is_parallel(model):
     return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

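A minimal usage sketch, mirroring the example in the profile() docstring above (assumes you run it from the YOLOv5 repo root so that utils.torch_utils is importable; thop is optional and only needed for the FLOPS column, since its import is guarded):

    import torch
    import torch.nn as nn

    from utils.torch_utils import profile

    x = torch.randn(16, 3, 640, 640)     # dummy input batch
    m1 = lambda x: x * torch.sigmoid(x)  # hand-written SiLU
    m2 = nn.SiLU()                       # built-in SiLU (PyTorch >= 1.7)
    profile(x, [m1, m2], n=100)          # prints Params, GFLOPS, forward/backward ms per op

Each op is timed over n forward and backward passes on the selected device, so the reported milliseconds reflect training-style cost rather than inference-only latency.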