from collections import OrderedDict, Counter, defaultdict
import json
import os
import sys

# Make the repository root importable when this script is run from a subfolder.
sys.path.append(os.path.dirname(sys.path[0]))

import numpy as np
from numpy import prod
from itertools import zip_longest
import tqdm
import logging
import typing
import torch
import torch.nn as nn
from functools import partial
import time

from util.slconfig import SLConfig

from typing import Any, Callable, List, Optional, Union
from numbers import Number

# A flop handle maps the (inputs, outputs) jit values of a traced node to a
# flop count, either as a Counter keyed by op name or as a plain number.
Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], Number]]

from main import build_model_main, get_args_parser as get_main_args_parser
from datasets import build_dataset
|
|
|
|
|
def get_shape(val: object) -> typing.List[int]:
    """
    Get the shapes from a jit value object.
    Args:
        val (torch._C.Value): jit value object.
    Returns:
        list(int): return a list of ints.
    """
    if val.isCompleteTensor():
        r = val.type().sizes()
        if not r:
            r = [1]
        return r
    elif val.type().kind() in ("IntType", "FloatType"):
        return [1]
    elif val.type().kind() in ("StringType",):
        return [0]
    elif val.type().kind() in ("ListType",):
        return [1]
    elif val.type().kind() in ("BoolType", "NoneType"):
        return [0]
    else:
        raise ValueError("Unsupported jit value type: {}".format(val.type().kind()))
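

# Minimal stand-in for torch._C.Value, used only by the illustrative
# _*_example functions below. These examples are hedged sketches of how the
# flop handles behave; they are defined for documentation and are never
# called at runtime.
class _FakeTensorValue:
    def __init__(self, sizes: typing.List[int]) -> None:
        self._sizes = list(sizes)

    def isCompleteTensor(self) -> bool:
        return True

    def type(self) -> "_FakeTensorValue":
        return self

    def sizes(self) -> typing.List[int]:
        return list(self._sizes)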
|
|
|
|
|
def addmm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for fully connected layers with torch script.
    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object.
    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # addmm computes bias + mat1 @ mat2, so the two matrices are inputs[1]
    # and inputs[2]; the bias (inputs[0]) is ignored.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch_size, input_dim]
    # input_shapes[1]: [input_dim, output_dim]
    assert len(input_shapes[0]) == 2
    assert len(input_shapes[1]) == 2
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flop = batch_size * input_dim * output_dim
    flop_counter = Counter({"addmm": flop})
    return flop_counter
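

def _addmm_flop_example() -> None:
    """
    Hedged sketch (never called): for mat1 of shape (8, 512) and mat2 of
    shape (512, 2048), aten::addmm costs 8 * 512 * 2048 multiply-accumulates;
    the bias argument (inputs[0]) is ignored by the handle.
    """
    bias = _FakeTensorValue([2048])
    mat1 = _FakeTensorValue([8, 512])
    mat2 = _FakeTensorValue([512, 2048])
    assert addmm_flop_jit([bias, mat1, mat2], [])["addmm"] == 8 * 512 * 2048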
|
|
|
|
|
def bmm_flop_jit(inputs, outputs):
    # Count flops for batched matrix multiplication: multiplying a (B, N, M)
    # tensor by a (B, M, P) tensor costs B * N * M * P multiply-accumulates.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes[0]) == 3
    assert len(input_shapes[1]) == 3
    B, N, M = input_shapes[0]
    P = input_shapes[1][2]
    flop = B * N * M * P
    flop_counter = Counter({"bmm": flop})
    return flop_counter
|
|
|
|
|
def basic_binary_op_flop_jit(inputs, outputs, name):
    input_shapes = [get_shape(v) for v in inputs]
    # Handle broadcasting: align shapes from the trailing dimension and take
    # the maximum along each axis, then charge one flop per output element.
    input_shapes = [s[::-1] for s in input_shapes]
    max_shape = np.array(list(zip_longest(*input_shapes, fillvalue=1))).max(1)
    flop = prod(max_shape)
    flop_counter = Counter({name: flop})
    return flop_counter
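

def _broadcast_flop_example() -> None:
    """
    Hedged sketch (never called): adding an (8, 1, 16) tensor to a (4, 16)
    tensor broadcasts to shape (8, 4, 16), so the handle charges
    8 * 4 * 16 = 512 flops to "aten::add".
    """
    a = _FakeTensorValue([8, 1, 16])
    b = _FakeTensorValue([4, 16])
    c = basic_binary_op_flop_jit([a, b], [], name="aten::add")
    assert c["aten::add"] == 8 * 4 * 16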
|
|
|
|
|
def rsqrt_flop_jit(inputs, outputs):
    # 2 flops per element: one for the square root, one for the reciprocal.
    input_shapes = [get_shape(v) for v in inputs]
    flop = prod(input_shapes[0]) * 2
    flop_counter = Counter({"rsqrt": flop})
    return flop_counter
|
|
|
|
|
def dropout_flop_jit(inputs, outputs):
    # One flop per element for the mask multiply; inputs[0] is the tensor.
    input_shapes = [get_shape(v) for v in inputs[:1]]
    flop = prod(input_shapes[0])
    flop_counter = Counter({"dropout": flop})
    return flop_counter
|
|
|
|
|
def softmax_flop_jit(inputs, outputs):
    # Heuristic of 5 flops per element (max, subtract, exp, sum, divide).
    input_shapes = [get_shape(v) for v in inputs[:1]]
    flop = prod(input_shapes[0]) * 5
    flop_counter = Counter({"softmax": flop})
    return flop_counter
|
|
|
|
|
def _reduction_op_flop_jit(inputs, outputs, reduce_flops=1, finalize_flops=0):
    # Note: this helper is not registered in _SUPPORTED_OPS and, unlike the
    # other handles, returns a raw number rather than a Counter.
    input_shapes = [get_shape(v) for v in inputs]
    output_shapes = [get_shape(v) for v in outputs]

    in_elements = prod(input_shapes[0])
    out_elements = prod(output_shapes[0])

    num_flops = in_elements * reduce_flops + out_elements * (
        finalize_flops - reduce_flops
    )

    return num_flops
|
|
|
|
|
def conv_flop_count(
    x_shape: typing.List[int],
    w_shape: typing.List[int],
    out_shape: typing.List[int],
) -> typing.Counter[str]:
    """
    This method counts the flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    batch_size, Cin_dim, Cout_dim = x_shape[0], w_shape[1], out_shape[1]
    out_size = prod(out_shape[2:])
    kernel_size = prod(w_shape[2:])
    flop = batch_size * out_size * Cout_dim * Cin_dim * kernel_size
    flop_counter = Counter({"conv": flop})
    return flop_counter
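

def _conv_flop_count_example() -> None:
    """
    Hedged sketch (never called): a 3x3 convolution taking a (1, 64, 56, 56)
    input to a (1, 128, 56, 56) output costs
    1 * (56 * 56) * 128 * 64 * (3 * 3) multiply-accumulates.
    """
    c = conv_flop_count([1, 64, 56, 56], [128, 64, 3, 3], [1, 128, 56, 56])
    assert c["conv"] == 1 * 56 * 56 * 128 * 64 * 9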
|
|
|
|
|
def conv_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for convolution using torch script.
    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object before convolution.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object after convolution.
    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs of aten::_convolution also include stride, padding, dilation,
    # groups, etc.; only the input tensor and the filter matter for the
    # multiplication count.
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (
        get_shape(x),
        get_shape(w),
        get_shape(outputs[0]),
    )
    return conv_flop_count(x_shape, w_shape, out_shape)
|
|
|
|
|
def einsum_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for the einsum operation. We currently support
    two einsum operations: "nct,ncp->ntp" and "ntg,ncg->nct".
    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object before einsum.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object after einsum.
    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs of einsum should be a list of length 2: the equation string and
    # the list of operand tensors.
    assert len(inputs) == 2
    equation = inputs[0].toIValue()
    # Get rid of white space in the equation string.
    equation = equation.replace(" ", "")
    # Canonicalize the equation so the same einsum written with different
    # letters compares equal: letters are renamed to a, b, c, ... in order of
    # first appearance, e.g. "nct,ncp->ntp" becomes "abc,abd->acd".
    letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
    mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
    equation = equation.translate(mapping)
    input_shapes_jit = inputs[1].node().inputs()
    input_shapes = [get_shape(v) for v in input_shapes_jit]

    if equation == "abc,abd->acd":
        n, c, t = input_shapes[0]
        p = input_shapes[-1][-1]
        flop = n * c * t * p
        flop_counter = Counter({"einsum": flop})
        return flop_counter

    elif equation == "abc,adc->adb":
        n, t, g = input_shapes[0]
        c = input_shapes[-1][1]
        flop = n * t * g * c
        flop_counter = Counter({"einsum": flop})
        return flop_counter

    else:
        raise NotImplementedError("Unsupported einsum operation.")
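

def _einsum_canonicalization_example() -> None:
    """
    Hedged sketch (never called): the letter renaming used by einsum_flop_jit
    maps "nct,ncp->ntp" to "abc,abd->acd", so the spelling of the equation
    does not matter.
    """
    equation = "nct,ncp->ntp"
    letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
    mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
    assert equation.translate(mapping) == "abc,abd->acd"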
|
|
|
|
|
def matmul_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for matmul.
    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object before matmul.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object after matmul.
    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs contains the shapes of the two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2
    assert input_shapes[0][-1] == input_shapes[1][-2]

    # Any leading dimensions are batch dimensions and must match.
    dim_len = len(input_shapes[1])
    assert dim_len >= 2
    batch = 1
    for i in range(dim_len - 2):
        assert input_shapes[0][i] == input_shapes[1][i]
        batch *= input_shapes[0][i]

    # (b, m, k) x (b, k, n) costs b * m * k * n multiply-accumulates.
    flop = batch * input_shapes[0][-2] * input_shapes[0][-1] * input_shapes[1][-1]
    flop_counter = Counter({"matmul": flop})
    return flop_counter
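

def _matmul_flop_example() -> None:
    """
    Hedged sketch (never called): multiplying a (2, 3, 4) tensor by a
    (2, 4, 5) tensor batches two (3, 4) x (4, 5) products, i.e.
    2 * 3 * 4 * 5 = 120 flops.
    """
    a = _FakeTensorValue([2, 3, 4])
    b = _FakeTensorValue([2, 4, 5])
    assert matmul_flop_jit([a, b], [])["matmul"] == 120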
|
|
|
|
|
def batchnorm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for batch norm.
    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object before batch norm.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object after batch norm.
    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Batch norm is assumed to cost 4 flops per input element
    # (subtract mean, divide by std, scale, shift).
    input_shape = get_shape(inputs[0])
    assert 2 <= len(input_shape) <= 5
    flop = prod(input_shape) * 4
    flop_counter = Counter({"batchnorm": flop})
    return flop_counter
|
|
|
|
|
def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]:
    """
    Count flops for the aten::linear operator.
    """
    # Inputs is a list of [input, weight, bias]; the weight has shape
    # (out_features, in_features).
    input_shapes = [get_shape(v) for v in inputs[0:2]]
    # input_shapes[0]: (*, in_features); input_shapes[1]: (out_features, in_features)
    assert input_shapes[0][-1] == input_shapes[1][-1]
    flops = prod(input_shapes[0]) * input_shapes[1][0]
    flop_counter = Counter({"linear": flops})
    return flop_counter
|
|
|
|
|
def norm_flop_counter(affine_arg_index: int) -> Handle:
    """
    Args:
        affine_arg_index: index of the affine argument in inputs
    """

    def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]:
        """
        Count flops for norm layers.
        """
        input_shape = get_shape(inputs[0])
        # get_shape maps a None affine argument to [0], so treat [0] (and
        # None, for safety) as "no affine parameters".
        has_affine = get_shape(inputs[affine_arg_index]) not in (None, [0])
        assert 2 <= len(input_shape) <= 5, input_shape
        # Roughly 4 flops per element for the normalization itself, plus 1
        # per element for the affine transform when present.
        flop = prod(input_shape) * (5 if has_affine else 4)
        flop_counter = Counter({"norm": flop})
        return flop_counter

    return norm_flop_jit
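

# For aten::layer_norm and aten::group_norm the affine weight is the third
# argument (index 2); for aten::instance_norm it is the second (index 1).
# This is why the registry below instantiates norm_flop_counter(2) and
# norm_flop_counter(1) respectively.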
|
|
|
|
|
def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Handle:
    """
    Count flops by
        input_tensor.numel() * input_scale + output_tensor.numel() * output_scale

    Args:
        input_scale: scale of the input tensor (first argument)
        output_scale: scale of the output tensor (first element in outputs)
    """

    def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]:
        ret = 0
        if input_scale != 0:
            shape = get_shape(inputs[0])
            ret += input_scale * prod(shape)
        if output_scale != 0:
            shape = get_shape(outputs[0])
            ret += output_scale * prod(shape)
        flop_counter = Counter({"elementwise": ret})
        return flop_counter

    return elementwise_flop
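

# Hedged reading of the registry entries below: a bilinear upsample is
# charged 4 flops per *output* element (elementwise_flop_counter(0, 4)),
# while pooling ops are charged 1 flop per *input* element
# (elementwise_flop_counter(1, 0)).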
|
|
|
|
|
|
|
# Maps each supported traced op to its flop handle. Note that unary
# elementwise ops (relu, sigmoid, log, sin, cos, ...) reuse
# basic_binary_op_flop_jit, which degenerates to one flop per element when
# given a single input.
_SUPPORTED_OPS: typing.Dict[str, typing.Callable] = {
    "aten::addmm": addmm_flop_jit,
    "aten::_convolution": conv_flop_jit,
    "aten::einsum": einsum_flop_jit,
    "aten::matmul": matmul_flop_jit,
    "aten::batch_norm": batchnorm_flop_jit,
    "aten::bmm": bmm_flop_jit,
    "aten::add": partial(basic_binary_op_flop_jit, name="aten::add"),
    "aten::add_": partial(basic_binary_op_flop_jit, name="aten::add_"),
    "aten::mul": partial(basic_binary_op_flop_jit, name="aten::mul"),
    "aten::sub": partial(basic_binary_op_flop_jit, name="aten::sub"),
    "aten::div": partial(basic_binary_op_flop_jit, name="aten::div"),
    "aten::floor_divide": partial(basic_binary_op_flop_jit, name="aten::floor_divide"),
    "aten::relu": partial(basic_binary_op_flop_jit, name="aten::relu"),
    "aten::relu_": partial(basic_binary_op_flop_jit, name="aten::relu_"),
    "aten::sigmoid": partial(basic_binary_op_flop_jit, name="aten::sigmoid"),
    "aten::log": partial(basic_binary_op_flop_jit, name="aten::log"),
    "aten::sum": partial(basic_binary_op_flop_jit, name="aten::sum"),
    "aten::sin": partial(basic_binary_op_flop_jit, name="aten::sin"),
    "aten::cos": partial(basic_binary_op_flop_jit, name="aten::cos"),
    "aten::pow": partial(basic_binary_op_flop_jit, name="aten::pow"),
    "aten::cumsum": partial(basic_binary_op_flop_jit, name="aten::cumsum"),
    "aten::rsqrt": rsqrt_flop_jit,
    "aten::softmax": softmax_flop_jit,
    "aten::dropout": dropout_flop_jit,
    "aten::linear": linear_flop_jit,
    "aten::group_norm": norm_flop_counter(2),
    "aten::layer_norm": norm_flop_counter(2),
    "aten::instance_norm": norm_flop_counter(1),
    "aten::upsample_nearest2d": elementwise_flop_counter(0, 1),
    "aten::upsample_bilinear2d": elementwise_flop_counter(0, 4),
    "aten::adaptive_avg_pool2d": elementwise_flop_counter(1, 0),
    "aten::max_pool2d": elementwise_flop_counter(1, 0),
    "aten::mm": matmul_flop_jit,
}
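
# Hedged extension sketch: an op missing from the table can be supplied at
# call time instead of editing this module, e.g.
#   flop_count(model, inputs,
#              customized_ops={"aten::gelu": elementwise_flop_counter(1, 0)})
# where charging one flop per input element to gelu is an assumption, not a
# measured cost.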
|
|
|
|
|
|
|
# Ops treated as flop-free (shape manipulation, indexing, bookkeeping, ...);
# they are skipped silently, without triggering the "skipped operation" warning.
_IGNORED_OPS: typing.List[str] = [
    "aten::Int",
    "aten::__and__",
    "aten::arange",
    "aten::cat",
    "aten::clamp",
    "aten::clamp_",
    "aten::contiguous",
    "aten::copy_",
    "aten::detach",
    "aten::empty",
    "aten::eq",
    "aten::expand",
    "aten::flatten",
    "aten::floor",
    "aten::full",
    "aten::gt",
    "aten::index",
    "aten::index_put_",
    "aten::max",
    "aten::nonzero",
    "aten::permute",
    "aten::remainder",
    "aten::reshape",
    "aten::select",
    "aten::gather",
    "aten::topk",
    "aten::meshgrid",
    "aten::masked_fill",
    "aten::linspace",
    "aten::size",
    "aten::slice",
    "aten::split_with_sizes",
    "aten::squeeze",
    "aten::t",
    "aten::to",
    "aten::transpose",
    "aten::unsqueeze",
    "aten::view",
    "aten::zeros",
    "aten::zeros_like",
    "aten::ones_like",
    "aten::new_zeros",
    "aten::all",
    "prim::Constant",
    "prim::Int",
    "prim::ListConstruct",
    "prim::ListUnpack",
    "prim::NumToTensor",
    "prim::TupleConstruct",
    "aten::stack",
    "aten::chunk",
    "aten::repeat",
    "aten::grid_sampler",
    "aten::constant_pad_nd",
]
|
|
|
# Ensures the skipped-ops warning is only logged once per process.
_HAS_ALREADY_SKIPPED = False
|
|
|
|
|
def flop_count(
    model: nn.Module,
    inputs: typing.Tuple[object, ...],
    whitelist: typing.Union[typing.List[str], None] = None,
    customized_ops: typing.Union[typing.Dict[str, typing.Callable], None] = None,
) -> typing.DefaultDict[str, float]:
    """
    Given a model and an input to the model, compute the Gflops of the given
    model. Note the input should have a batch size of 1.
    Args:
        model (nn.Module): The model to compute flop counts.
        inputs (tuple): Inputs that are passed to `model` to count flops.
            Inputs need to be in a tuple.
        whitelist (list(str)): Whitelist of operations that will be counted. It
            needs to be a subset of _SUPPORTED_OPS. By default, the function
            computes flops for all supported operations.
        customized_ops (dict(str,Callable)): A dictionary that contains
            customized operations and their flop handles. If customized_ops
            contains an operation in _SUPPORTED_OPS, then the default handle in
            _SUPPORTED_OPS will be overwritten.
    Returns:
        defaultdict: A dictionary that records the number of gflops for each
            operation.
    """
    # Copy _SUPPORTED_OPS and overlay any caller-provided handles.
    flop_count_ops = _SUPPORTED_OPS.copy()
    if customized_ops:
        flop_count_ops.update(customized_ops)

    # If whitelist is None, count flops for all supported operations.
    if whitelist is None:
        whitelist_set = set(flop_count_ops.keys())
    else:
        whitelist_set = set(whitelist)

    # Torch script does not support parallel-wrapped models, so unwrap them.
    if isinstance(
        model,
        (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel),
    ):
        model = model.module

    assert whitelist_set.issubset(
        flop_count_ops
    ), "whitelist needs to be a subset of _SUPPORTED_OPS and customized_ops."
    assert isinstance(inputs, tuple), "Inputs need to be in a tuple."

    # Compatibility shim: newer torch versions renamed the private tracing
    # helper from get_trace_graph to _get_trace_graph.
    if hasattr(torch.jit, "get_trace_graph"):
        trace, _ = torch.jit.get_trace_graph(model, inputs)
        trace_nodes = trace.graph().nodes()
    else:
        trace, _ = torch.jit._get_trace_graph(model, inputs)
        trace_nodes = trace.nodes()

    skipped_ops = Counter()
    total_flop_counter = Counter()

    for node in trace_nodes:
        kind = node.kind()
        if kind not in whitelist_set:
            # Only warn about ops that are neither supported nor known to be
            # flop-free.
            if kind not in _IGNORED_OPS:
                skipped_ops[kind] += 1
            continue

        handle_count = flop_count_ops.get(kind, None)
        if handle_count is None:
            continue

        # Use distinct names so the `inputs` argument above is not shadowed.
        node_inputs, node_outputs = list(node.inputs()), list(node.outputs())
        flops_counter = handle_count(node_inputs, node_outputs)
        total_flop_counter += flops_counter

    # Log the skipped operations at most once per process.
    global _HAS_ALREADY_SKIPPED
    if len(skipped_ops) > 0 and not _HAS_ALREADY_SKIPPED:
        _HAS_ALREADY_SKIPPED = True
        for op, freq in skipped_ops.items():
            logging.warning("Skipped operation {} {} time(s)".format(op, freq))

    # Convert raw flop counts to gigaflops.
    final_count = defaultdict(float)
    for op in total_flop_counter:
        final_count[op] = total_flop_counter[op] / 1e9

    return final_count
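

def _flop_count_usage_example() -> None:
    """
    Hedged usage sketch (never called): count Gflops for a toy MLP with a
    batch size of 1, as flop_count expects. Depending on the torch version,
    the Linear layers show up as aten::addmm or aten::linear nodes and
    contribute (1 * 16 * 32 + 1 * 32 * 4) / 1e9 Gflops; the ReLU is counted
    under its own key.
    """
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    gflops = flop_count(model, (torch.randn(1, 16),))
    print(dict(gflops))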
|
|
|
|
|
def get_dataset(coco_path):
    """
    Build the COCO validation dataset used for computing flops.
    """

    class DummyArgs:
        pass

    args = DummyArgs()
    args.dataset_file = "coco"
    args.coco_path = coco_path
    args.masks = False
    dataset = build_dataset(image_set="val", args=args)
    return dataset
|
|
|
|
|
def warmup(model, inputs, N=10):
    # Run a few forward passes so CUDA kernels are compiled and cached
    # before timing starts.
    for i in range(N):
        out = model(inputs)
    torch.cuda.synchronize()


def measure_time(model, inputs, N=10):
    # Average wall-clock seconds per forward pass over N runs, with an
    # explicit CUDA synchronize so async kernels are fully accounted for.
    warmup(model, inputs)
    s = time.time()
    for i in range(N):
        out = model(inputs)
    torch.cuda.synchronize()
    t = (time.time() - s) / N
    return t
|
|
|
|
|
def fmt_res(data):
    # Cast numpy scalars to plain floats so the result is JSON-serializable.
    return {
        "mean": float(data.mean()),
        "std": float(data.std()),
        "min": float(data.min()),
        "max": float(data.max()),
    }
|
|
|
|
|
def benchmark():
    _outputs = {}
    main_args = get_main_args_parser().parse_args()
    main_args.command_txt = "Command: " + " ".join(sys.argv)

    # Load the config file and merge its keys into the parsed args.
    print("Loading config file from {}".format(main_args.config_file))
    cfg = SLConfig.fromfile(main_args.config_file)
    if main_args.options is not None:
        cfg.merge_from_dict(main_args.options)
    cfg_dict = cfg._cfg_dict.to_dict()
    args_vars = vars(main_args)
    for k, v in cfg_dict.items():
        if k not in args_vars:
            setattr(main_args, k, v)
        else:
            raise ValueError("Key {} can only be used by args.".format(k))

    dataset = build_dataset("val", main_args)
    model, _, _ = build_model_main(main_args)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    _outputs.update({"nparam": n_parameters})

    model.cuda()
    model.eval()

    warmup_step = 5
    total_step = 20

    images = []
    for idx in range(total_step):
        img, t = dataset[idx]
        images.append(img)

    with torch.no_grad():
        tmp = []
        tmp2 = []
        for imgid, img in enumerate(tqdm.tqdm(images)):
            inputs = [img.to("cuda")]
            res = flop_count(model, (inputs,))
            t = measure_time(model, inputs)
            tmp.append(sum(res.values()))
            # Skip the first few iterations when timing so GPU warmup does
            # not skew the statistics.
            if imgid >= warmup_step:
                tmp2.append(t)
        _outputs.update({"detailed_flops": res})
    _outputs.update({"flops": fmt_res(np.array(tmp)), "time": fmt_res(np.array(tmp2))})

    mean_infer_time = float(fmt_res(np.array(tmp2))["mean"])
    _outputs.update({"fps": 1 / mean_infer_time})

    res = {"flops": fmt_res(np.array(tmp)), "time": fmt_res(np.array(tmp2))}

    output_file = os.path.join(main_args.output_dir, "flops", "log.txt")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "a") as f:
        f.write(main_args.command_txt + "\n")
        f.write(json.dumps(_outputs, indent=2) + "\n")

    return _outputs
|
|
|
|
|
if __name__ == "__main__":
    res = benchmark()
    print(json.dumps(res, indent=2))
|
|