# zdou0830's picture
# desco
# 749745d
import argparse
import logging
import torch
import torch.nn as nn
import timeit
from maskrcnn_benchmark.layers import *
from maskrcnn_benchmark.modeling.backbone.resnet_big import StdConv2d
from maskrcnn_benchmark.modeling.backbone.fpn import *
from maskrcnn_benchmark.modeling.rpn.inference import *
from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import PostProcessor
from maskrcnn_benchmark.modeling.rpn.anchor_generator import BufferList
def profile(model, input_size, custom_ops=None, device="cpu", verbose=False, extra_args=None, return_time=False):
    """Count the ops and parameters of ``model`` with one dummy forward pass.

    Every leaf module (a module with no children) receives ``total_ops`` and
    ``total_params`` buffers. ``total_params`` is filled from the leaf's own
    parameters. ``total_ops`` is filled by the forward hook registered for the
    leaf's exact type: ``custom_ops`` is checked first, then ``register_hooks``.
    The model is run once on ``torch.zeros(input_size)`` under
    ``torch.no_grad()`` and the per-leaf values are summed. Parameters owned
    directly by a non-leaf module are therefore not counted.

    The hooks are removed, and the model's train/eval mode and device are
    restored, even if the forward pass raises. The device can only be
    restored when the model has at least one parameter or buffer. The
    ``total_ops`` / ``total_params`` buffers stay registered on the leaf
    modules afterwards.

    Args:
        model: module to profile.
        input_size: shape of the zero-filled dummy input.
        custom_ops: optional ``{module_type: hook}`` overrides. A ``None`` hook
            marks the type as known but uncounted.
        device: device the model and the dummy input are moved to.
        verbose: print every module a counter is registered for.
        extra_args: optional keyword arguments forwarded to ``model(x, ...)``.
        return_time: also return the duration of the forward pass in seconds.

    Returns:
        ``(total_ops, total_params)`` as Python floats, or
        ``(total_ops, total_params, total_time)`` when ``return_time`` is true.
    """
    # ``None`` defaults avoid sharing one mutable dict across calls.
    custom_ops = {} if custom_ops is None else custom_ops
    extra_args = {} if extra_args is None else extra_args
    handler_collection = []

    def add_hooks(m):
        # Only leaf modules are counted; containers would double count.
        if len(list(m.children())) > 0:
            return
        m.register_buffer("total_ops", torch.zeros(1))
        m.register_buffer("total_params", torch.zeros(1))
        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])
        m_type = type(m)
        fn = None
        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
        else:
            print("Not implemented for ", m)
        if fn is not None:
            if verbose:
                print("Register FLOP counter for module %s" % str(m))
            handler_collection.append(m.register_forward_hook(fn))

    # Remember where the model lives so it can be put back afterwards. Fall
    # back to buffers, then to "unknown", for parameter-less models.
    reference = next(model.parameters(), None)
    if reference is None:
        reference = next(model.buffers(), None)
    original_device = None if reference is None else reference.device
    training = model.training
    # CUDA kernels run asynchronously; synchronize so the timer covers them.
    sync_cuda = torch.device(device).type == "cuda"

    try:
        model.eval().to(device)
        model.apply(add_hooks)
        x = torch.zeros(input_size).to(device)
        with torch.no_grad():
            if sync_cuda:
                torch.cuda.synchronize(torch.device(device))
            tic = timeit.default_timer()
            model(x, **extra_args)
            if sync_cuda:
                torch.cuda.synchronize(torch.device(device))
            toc = timeit.default_timer()
        total_time = toc - tic

        total_ops = 0
        total_params = 0
        for m in model.modules():
            if len(list(m.children())) > 0:  # skip non-leaf modules
                continue
            total_ops += m.total_ops
            total_params += m.total_params
        total_ops = total_ops.item()
        total_params = total_params.item()
    finally:
        # Leave the model as we found it even if the forward pass failed.
        model.train(training)
        if original_device is not None:
            model.to(original_device)
        for handler in handler_collection:
            handler.remove()

    if return_time:
        return total_ops, total_params, total_time
    return total_ops, total_params
# Multiplier on the per-output-element kernel cost in the convolution
# counters below. With 1, each multiply-accumulate counts as a single op
# (presumably 2 would count the multiply and the add separately).
multiply_adds = 1
def count_conv2d(m, x, y):
    """Forward hook: store the op count of a 2-D convolution in ``m.total_ops``.

    Every output element costs ``kh * kw * in_channels / groups`` kernel ops
    (scaled by the module-level ``multiply_adds``). A layer with a bias adds
    one more op per element. The output element count is
    ``batch * out_h * out_w * out_channels``.
    """
    kh, kw = m.kernel_size
    per_output = multiply_adds * kh * kw * m.in_channels // m.groups
    if m.bias is not None:
        per_output += 1
    n_outputs = x[0].shape[0] * y.size(3) * y.size(2) * m.out_channels
    m.total_ops = torch.Tensor([int(n_outputs * per_output)])
def count_convtranspose2d(m, x, y):
    """Forward hook: store the op count of a 2-D transposed conv in ``m.total_ops``.

    A transposed convolution scatters every input element through
    ``out_channels / groups * kh * kw`` weights. The kernel cost is therefore
    ``x.numel() * (out_channels // groups) * kh * kw`` multiply-accumulates,
    scaled by the module-level ``multiply_adds``. When the layer has a bias,
    one extra op per output element is added.
    """
    inp = x[0]
    kh, kw = m.kernel_size
    # Count from the input side: charging each output element with the full
    # weight tensor (weight.nelement() * y.nelement()) grossly overcounts.
    kernel_ops = multiply_adds * inp.numel() * (m.out_channels // m.groups) * kh * kw
    bias_ops = y.numel() if m.bias is not None else 0
    m.total_ops = torch.Tensor([int(kernel_ops + bias_ops)])
def count_bn(m, x, y):
    """Forward hook: charge 4 ops per input element (subtract, divide, gamma, beta)."""
    ops_per_element = 4
    m.total_ops = torch.Tensor([int(ops_per_element * x[0].numel())])
def count_relu(m, x, y):
    """Forward hook: charge one op per element of the activation's input."""
    element_count = x[0].numel()
    m.total_ops = torch.Tensor([int(element_count)])
def count_softmax(m, x, y):
    """Forward hook: store the op count of a softmax in ``m.total_ops``.

    Each softmax row of ``nfeatures`` values costs ``nfeatures``
    exponentials, ``nfeatures - 1`` additions and ``nfeatures`` divisions.
    The reduced axis is ``m.dim``. When the module has no ``dim`` (or it is
    ``None``), torch's legacy implicit rule is mirrored: axis 0 for 0-, 1-
    and 3-D input, axis 1 otherwise. All remaining axes count as independent
    rows, so inputs of any rank are supported. Empty input costs 0 ops.
    """
    inp = x[0]
    if inp.numel() == 0:
        m.total_ops = torch.Tensor([0])
        return
    dim = getattr(m, "dim", None)
    if dim is None:
        dim = 0 if inp.dim() in (0, 1, 3) else 1
    nfeatures = inp.size(dim) if inp.dim() > 0 else 1
    rows = inp.numel() // nfeatures
    total_exp = nfeatures
    total_add = nfeatures - 1
    total_div = nfeatures
    total_ops = rows * (total_exp + total_add + total_div)
    m.total_ops = torch.Tensor([int(total_ops)])
def _pool_spatial_ndim(m, y):
    """Return how many spatial dimensions pooling module ``m`` reduces.

    Stock 1-D/2-D/3-D pooling classes are matched by type. Anything else falls
    back to ``y.dim() - 2``, which assumes a batched ``(N, C, *spatial)``
    output (TODO confirm for custom pooling modules).
    """
    if isinstance(m, (nn.MaxPool1d, nn.AvgPool1d, nn.AdaptiveMaxPool1d, nn.AdaptiveAvgPool1d)):
        return 1
    if isinstance(m, (nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveMaxPool2d, nn.AdaptiveAvgPool2d)):
        return 2
    if isinstance(m, (nn.MaxPool3d, nn.AvgPool3d, nn.AdaptiveMaxPool3d, nn.AdaptiveAvgPool3d)):
        return 3
    return max(y.dim() - 2, 1)


def _pool_kernel_numel(m, y):
    """Return the number of elements in the pooling window of ``m``.

    ``nn.MaxPool2d(3)`` and ``nn.AvgPool2d(3)`` store ``kernel_size`` as the
    bare int ``3``. It is expanded to one entry per spatial dimension, so the
    window counts as 3 x 3 rather than 3.
    """
    size = m.kernel_size
    if isinstance(size, int):
        size = (size,) * _pool_spatial_ndim(m, y)
    numel = 1
    for k in size:
        numel *= int(k)
    return numel


def _adaptive_kernel_numel(m, x, y):
    """Approximate the window size of adaptive pooling module ``m``.

    Each spatial input extent is floor-divided by the requested output
    extent. A ``None`` entry in ``output_size`` keeps the input extent,
    which gives a window of 1 along that axis.
    """
    ndim = _pool_spatial_ndim(m, y)
    in_sizes = x[0].shape[-ndim:]
    out_sizes = m.output_size
    if not isinstance(out_sizes, (tuple, list)):
        out_sizes = (out_sizes,) * ndim
    numel = 1
    for in_size, out_size in zip(in_sizes, out_sizes):
        if out_size is None:
            out_size = in_size
        numel *= int(in_size) // int(out_size)
    return numel


def count_maxpool(m, x, y):
    """Forward hook: charge window-size comparison ops per output element."""
    total_ops = _pool_kernel_numel(m, y) * y.numel()
    m.total_ops = torch.Tensor([int(total_ops)])


def count_adap_maxpool(m, x, y):
    """Forward hook: like ``count_maxpool`` with the adaptive window size."""
    total_ops = _adaptive_kernel_numel(m, x, y) * y.numel()
    m.total_ops = torch.Tensor([int(total_ops)])


def count_avgpool(m, x, y):
    """Forward hook: charge window-size additions plus one division per output."""
    total_add = _pool_kernel_numel(m, y)
    total_div = 1
    total_ops = (total_add + total_div) * y.numel()
    m.total_ops = torch.Tensor([int(total_ops)])


def count_adap_avgpool(m, x, y):
    """Forward hook: like ``count_avgpool`` with the adaptive window size."""
    total_add = _adaptive_kernel_numel(m, x, y)
    total_div = 1
    total_ops = (total_add + total_div) * y.numel()
    m.total_ops = torch.Tensor([int(total_ops)])
def count_linear(m, x, y):
    """Forward hook: store the op count of a fully connected layer.

    Each output element costs ``in_features`` multiplications and
    ``in_features - 1`` additions (the bias is not counted).
    """
    per_output = m.in_features + (m.in_features - 1)
    m.total_ops = torch.Tensor([int(per_output * y.numel())])
def count_LastLevelMaxPool(m, x, y):
    """Forward hook: charge one op per element of the last output tensor.

    ``y`` is indexed as a sequence, and only ``y[-1]`` contributes.
    """
    last_output = y[-1]
    m.total_ops = torch.Tensor([int(last_output.numel())])
def count_ROIAlign(m, x, y):
    """Forward hook: charge a flat 4 ops per pooled output element."""
    ops_per_output = 4
    m.total_ops = torch.Tensor([int(ops_per_output * y.numel())])
# Default op counters, keyed by *exact* module type: ``profile`` looks up
# ``type(m)``, so subclasses are not matched and need their own entry.
# ``profile`` checks its ``custom_ops`` argument first, then this table.
# A ``None`` value marks a type as known but uncounted: no hook is attached,
# so the module contributes 0 ops and no "Not implemented" message is printed.
# NOTE: there is no ConvTranspose2d entry, so ``count_convtranspose2d`` is
# only used when it is passed explicitly through ``custom_ops``.
register_hooks = {
    Scale: None,
    # Convolution variants; any project wrapper needs its own entry because
    # lookup is by exact type.
    Conv2d: count_conv2d,
    nn.Conv2d: count_conv2d,
    ModulatedDeformConv: count_conv2d,
    StdConv2d: count_conv2d,
    # Normalization layers all share the 4-ops-per-element batch-norm counter.
    nn.BatchNorm1d: count_bn,
    nn.BatchNorm2d: count_bn,
    nn.BatchNorm3d: count_bn,
    FrozenBatchNorm2d: count_bn,
    nn.GroupNorm: count_bn,
    NaiveSyncBatchNorm2d: count_bn,
    # Activations and padding.
    nn.ReLU: count_relu,
    nn.ReLU6: count_relu,
    swish: None,
    nn.ConstantPad2d: None,
    # Both use the counter that reads only the last tensor of the output.
    SPPLayer: count_LastLevelMaxPool,
    LastLevelMaxPool: count_LastLevelMaxPool,
    # Pooling.
    nn.MaxPool1d: count_maxpool,
    nn.MaxPool2d: count_maxpool,
    nn.MaxPool3d: count_maxpool,
    nn.AdaptiveMaxPool1d: count_adap_maxpool,
    nn.AdaptiveMaxPool2d: count_adap_maxpool,
    nn.AdaptiveMaxPool3d: count_adap_maxpool,
    nn.AvgPool1d: count_avgpool,
    nn.AvgPool2d: count_avgpool,
    nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool,
    nn.AdaptiveAvgPool2d: count_adap_avgpool,
    nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear,
    # Layers deliberately left uncounted (0 ops).
    nn.Upsample: None,
    nn.Dropout: None,
    nn.Sigmoid: None,
    DropBlock2D: None,
    ROIAlign: count_ROIAlign,
    # Detection post-processing and anchor buffers are left uncounted.
    RPNPostProcessor: None,
    PostProcessor: None,
    BufferList: None,
    RetinaPostProcessor: None,
    FCOSPostProcessor: None,
    ATSSPostProcessor: None,
}