| | |
| |
|
| | import typing |
| | from collections import defaultdict |
| | from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union |
| |
|
| | import torch.nn as nn |
| | from rich import box |
| | from rich.console import Console |
| | from rich.table import Table |
| | from torch import Tensor |
| |
|
| | from .jit_analysis import JitModelAnalysis |
| | from .jit_handles import (Handle, addmm_flop_jit, batchnorm_flop_jit, |
| | bmm_flop_jit, conv_flop_jit, einsum_flop_jit, |
| | elementwise_flop_counter, generic_activation_jit, |
| | linear_flop_jit, matmul_flop_jit, norm_flop_counter) |
| |
|
| | |
# Default flop handles: maps a traced operator name (``aten::*``) to the
# function computing its flop count from the op's (inputs, outputs).
# One fused multiply-add is counted as one flop (see FlopAnalyzer's note).
_DEFAULT_SUPPORTED_FLOP_OPS: Dict[str, Handle] = {
    # Matrix-multiplication style ops, each with a dedicated counter.
    'aten::addmm': addmm_flop_jit,
    'aten::bmm': bmm_flop_jit,
    'aten::_convolution': conv_flop_jit,
    'aten::einsum': einsum_flop_jit,
    'aten::matmul': matmul_flop_jit,
    'aten::mm': matmul_flop_jit,
    'aten::linear': linear_flop_jit,

    # Normalization / resampling / pooling ops, counted by the generic
    # counter factories. NOTE(review): the integer arguments presumably
    # scale flops per input/output element — confirm in jit_handles.
    'aten::batch_norm': batchnorm_flop_jit,
    'aten::group_norm': norm_flop_counter(2),
    'aten::layer_norm': norm_flop_counter(2),
    'aten::instance_norm': norm_flop_counter(1),
    'aten::upsample_nearest2d': elementwise_flop_counter(0, 1),
    'aten::upsample_bilinear2d': elementwise_flop_counter(0, 4),
    'aten::adaptive_avg_pool2d': elementwise_flop_counter(1, 0),
    'aten::grid_sampler': elementwise_flop_counter(0, 4),
}
| |
|
| | |
| | |
# Default activation handles: maps a traced operator name (``aten::*``) to
# a generic counter of produced activations. Only convolution and
# dot-product style operators are counted by default; '_convolution' is
# reported under the shorter name 'conv'.
_DEFAULT_SUPPORTED_ACT_OPS: Dict[str, Handle] = {
    'aten::_convolution': generic_activation_jit('conv'),
    'aten::addmm': generic_activation_jit(),
    'aten::bmm': generic_activation_jit(),
    'aten::einsum': generic_activation_jit(),
    'aten::matmul': generic_activation_jit(),
    'aten::linear': generic_activation_jit(),
}
| |
|
| |
|
class FlopAnalyzer(JitModelAnalysis):
    """Per-submodule flop counts obtained by tracing a model with pytorch's
    jit tracing functionality.

    Standard flop handles for several common operators are installed by
    default.

    Note:
        - Flop is not a well-defined concept; the numbers produced here are
          a best estimate.
        - One fused multiply-add is counted as one flop.

    Handles for further operators can be registered — and the defaults
    replaced — with ``.set_op_handle(name, func)``; see that method's
    documentation for details. Flop counts are then available via:

    - ``.total(module_name="")``: total flop count for the module
    - ``.by_operator(module_name="")``: flop counts for the module, as a
      Counter over different operator types
    - ``.by_module()``: Counter of flop counts for all submodules
    - ``.by_module_and_operator()``: dictionary indexed by descendant of
      Counters over different operator types

    An operator is attributed to a module when it executes inside that
    module's ``__call__``. Calls to the module's other methods, or explicit
    ``module.forward(...)`` invocations, are not attributed.

    Modified from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

    Args:
        model (nn.Module): The model to analyze.
        inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.

    Examples:
        >>> import torch.nn as nn
        >>> import torch
        >>> class TestModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.fc = nn.Linear(in_features=1000, out_features=10)
        ...         self.conv = nn.Conv2d(
        ...             in_channels=3, out_channels=10, kernel_size=1
        ...         )
        ...         self.act = nn.ReLU()
        ...     def forward(self, x):
        ...         return self.fc(self.act(self.conv(x)).flatten(1))
        >>> model = TestModel()
        >>> inputs = (torch.randn((1,3,10,10)),)
        >>> flops = FlopAnalyzer(model, inputs)
        >>> flops.total()
        13000
        >>> flops.total("fc")
        10000
        >>> flops.by_operator()
        Counter({"addmm" : 10000, "conv" : 3000})
        >>> flops.by_module()
        Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
        >>> flops.by_module_and_operator()
        {"" : Counter({"addmm" : 10000, "conv" : 3000}),
        "fc" : Counter({"addmm" : 10000}),
        "conv" : Counter({"conv" : 3000}),
        "act" : Counter()
        }
    """

    def __init__(self, model: nn.Module,
                 inputs: Union[Tensor, Tuple[Tensor, ...]]) -> None:
        super().__init__(model=model, inputs=inputs)
        # Install the default flop handles; callers may override them later
        # through set_op_handle.
        self.set_op_handle(**_DEFAULT_SUPPORTED_FLOP_OPS)

    # Reuse the base-class constructor documentation verbatim.
    __init__.__doc__ = JitModelAnalysis.__init__.__doc__
| |
|
| |
|
class ActivationAnalyzer(JitModelAnalysis):
    """Per-submodule activation counts obtained by tracing a model with
    pytorch's jit tracing functionality.

    Standard activation counters for convolutional and dot-product
    operators are installed by default. Handles for further operators can
    be registered — and the defaults replaced — with
    ``.set_op_handle(name, func)``; see that method's documentation for
    details. Activation counts are then available via:

    - ``.total(module_name="")``: total activation count for a module
    - ``.by_operator(module_name="")``: activation counts for the module,
      as a Counter over different operator types
    - ``.by_module()``: Counter of activation counts for all submodules
    - ``.by_module_and_operator()``: dictionary indexed by descendant of
      Counters over different operator types

    An operator is attributed to a module when it executes inside that
    module's ``__call__``. Calls to the module's other methods, or explicit
    ``module.forward(...)`` invocations, are not attributed.

    Modified from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py

    Args:
        model (nn.Module): The model to analyze.
        inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.

    Examples:
        >>> import torch.nn as nn
        >>> import torch
        >>> class TestModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.fc = nn.Linear(in_features=1000, out_features=10)
        ...         self.conv = nn.Conv2d(
        ...             in_channels=3, out_channels=10, kernel_size=1
        ...         )
        ...         self.act = nn.ReLU()
        ...     def forward(self, x):
        ...         return self.fc(self.act(self.conv(x)).flatten(1))
        >>> model = TestModel()
        >>> inputs = (torch.randn((1,3,10,10)),)
        >>> acts = ActivationAnalyzer(model, inputs)
        >>> acts.total()
        1010
        >>> acts.total("fc")
        10
        >>> acts.by_operator()
        Counter({"conv" : 1000, "addmm" : 10})
        >>> acts.by_module()
        Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0})
        >>> acts.by_module_and_operator()
        {"" : Counter({"conv" : 1000, "addmm" : 10}),
        "fc" : Counter({"addmm" : 10}),
        "conv" : Counter({"conv" : 1000}),
        "act" : Counter()
        }
    """

    def __init__(self, model: nn.Module,
                 inputs: Union[Tensor, Tuple[Tensor, ...]]) -> None:
        super().__init__(model=model, inputs=inputs)
        # Install the default activation handles; callers may override them
        # later through set_op_handle.
        self.set_op_handle(**_DEFAULT_SUPPORTED_ACT_OPS)

    # Reuse the base-class constructor documentation verbatim.
    __init__.__doc__ = JitModelAnalysis.__init__.__doc__
| |
|
| |
|
def flop_count(
    model: nn.Module,
    inputs: Tuple[Any, ...],
    supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
    """Compute the per-operator Gflops of ``model`` on the given input.

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

    Args:
        model (nn.Module): The model to compute flop counts.
        inputs (tuple): Inputs that are passed to `model` to count flops.
            Inputs need to be in a tuple.
        supported_ops (dict(str,Callable) or None) : provide additional
            handlers for extra ops, or overwrite the existing handlers for
            convolution and matmul and einsum. The key is operator name and
            the value is a function that takes (inputs, outputs) of the op.
            We count one Multiply-Add as one FLOP.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of
        gflops for each operation and a Counter that records the number of
        unsupported operations.
    """
    analyzer = FlopAnalyzer(model, inputs)
    # Register caller-provided handles on top of the defaults.
    analyzer.set_op_handle(**(supported_ops or {}))
    gflops: DefaultDict[str, float] = defaultdict(float)
    for op_name, num_flops in analyzer.by_operator().items():
        gflops[op_name] = num_flops / 1e9  # raw flops -> Gflops
    return gflops, analyzer.unsupported_ops()
| |
|
| |
|
def activation_count(
    model: nn.Module,
    inputs: Tuple[Any, ...],
    supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
    """Compute the total number of activations of ``model`` on the given
    input.

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py

    Args:
        model (nn.Module): The model to compute activation counts.
        inputs (tuple): Inputs that are passed to `model` to count activations.
            Inputs need to be in a tuple.
        supported_ops (dict(str,Callable) or None) : provide additional
            handlers for extra ops, or overwrite the existing handlers for
            convolution and matmul. The key is operator name and the value
            is a function that takes (inputs, outputs) of the op.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of
        activation (mega) for each operation and a Counter that records the
        number of unsupported operations.
    """
    analyzer = ActivationAnalyzer(model, inputs)
    # Register caller-provided handles on top of the defaults.
    analyzer.set_op_handle(**(supported_ops or {}))
    mega_acts: DefaultDict[str, float] = defaultdict(float)
    for op_name, num_acts in analyzer.by_operator().items():
        mega_acts[op_name] = num_acts / 1e6  # raw count -> mega-activations
    return mega_acts, analyzer.unsupported_ops()
| |
|
| |
|
def parameter_count(model: nn.Module) -> typing.DefaultDict[str, int]:
    """Count parameters of a model and its submodules.

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py

    Args:
        model (nn.Module): the model to count parameters.

    Returns:
        dict[str, int]: the key is either a parameter name or a module name.
        The value is the number of elements in the parameter, or in all
        parameters of the module. The key "" corresponds to the total
        number of parameters of the model.
    """
    counts: typing.DefaultDict[str, int] = defaultdict(int)
    for full_name, param in model.named_parameters():
        numel = param.numel()
        parts = full_name.split('.')
        # Credit this parameter to every ancestor prefix: the root '' up to
        # and including the parameter's own dotted name.
        for depth in range(len(parts) + 1):
            counts['.'.join(parts[:depth])] += numel
    return counts
| |
|
| |
|
def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str:
    """Format the parameter count of the model (and its submodules or
    parameters) as a rich table.

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py

    Args:
        model (nn.Module): the model to count parameters.
        max_depth (int): maximum depth to recursively print submodules or
            parameters

    Returns:
        str: the table to be printed
    """
    counts: typing.DefaultDict[str, int] = parameter_count(model)
    # Shapes of leaf parameters, keyed by their dotted name; used to show a
    # shape instead of an element count for parameter rows.
    shapes: typing.Dict[str, typing.Tuple] = {
        name: tuple(p.shape)
        for name, p in model.named_parameters()
    }
    table_rows: typing.List[typing.Tuple] = []

    def _human_readable(num: int) -> str:
        # Abbreviate an element count as G/M/K with one decimal.
        if num > 1e8:
            return f'{num / 1e9:.1f}G'
        if num > 1e5:
            return f'{num / 1e6:.1f}M'
        if num > 1e2:
            return f'{num / 1e3:.1f}K'
        return str(num)

    def _collect(depth: int, prefix: str) -> None:
        # Depth-first walk over dotted names: one row per name at this
        # depth under `prefix`, then recurse into its children.
        if depth >= max_depth:
            return
        for name, total in counts.items():
            if name.count('.') == depth and name.startswith(prefix):
                pad = ' ' * (depth + 1)
                if name in shapes:
                    table_rows.append((pad + name, pad + str(shapes[name])))
                else:
                    table_rows.append((pad + name, pad + _human_readable(total)))
                _collect(depth + 1, name + '.')

    # The root entry ('') becomes the 'model' row; remove it so the walk
    # below only sees real submodule/parameter names.
    table_rows.append(('model', _human_readable(counts.pop(''))))
    _collect(0, '')

    table = Table(
        title=f'parameter count of {model.__class__.__name__}', box=box.ASCII2)
    table.add_column('name')
    table.add_column('#elements or shape')
    for row in table_rows:
        table.add_row(*row)

    # Render the table to a string instead of printing it.
    console = Console()
    with console.capture() as capture:
        console.print(table, end='')
    return capture.get()
| |
|