""" Copyright (C) 2019 Sovrasov V. - All Rights Reserved * You may use, distribute and modify this code under the * terms of the MIT license. * You should have received a copy of the MIT license with * this file. If not visit https://opensource.org/licenses/MIT """ import sys from functools import partial import numpy as np import torch import torch.nn as nn from maskrcnn_benchmark.layers import * def get_model_complexity_info( model, input_res, print_per_layer_stat=True, as_strings=True, input_constructor=None, ost=sys.stdout, verbose=False, ignore_modules=[], custom_modules_hooks={}, ): assert type(input_res) is tuple assert len(input_res) >= 1 assert isinstance(model, nn.Module) global CUSTOM_MODULES_MAPPING CUSTOM_MODULES_MAPPING = custom_modules_hooks flops_model = add_flops_counting_methods(model) flops_model.eval() flops_model.start_flops_count(ost=ost, verbose=verbose, ignore_list=ignore_modules) if input_constructor: input = input_constructor(input_res) _ = flops_model(**input) else: try: batch = torch.ones(()).new_empty( (1, *input_res), dtype=next(flops_model.parameters()).dtype, device=next(flops_model.parameters()).device, ) except StopIteration: batch = torch.ones(()).new_empty((1, *input_res)) _ = flops_model(batch) flops_count, params_count = flops_model.compute_average_flops_cost() if print_per_layer_stat: print_model_with_flops(flops_model, flops_count, params_count, ost=ost) flops_model.stop_flops_count() CUSTOM_MODULES_MAPPING = {} if as_strings: return flops_to_string(flops_count), params_to_string(params_count) return flops_count, params_count def flops_to_string(flops, units="GMac", precision=2): if units is None: if flops // 10**9 > 0: return str(round(flops / 10.0**9, precision)) + " GMac" elif flops // 10**6 > 0: return str(round(flops / 10.0**6, precision)) + " MMac" elif flops // 10**3 > 0: return str(round(flops / 10.0**3, precision)) + " KMac" else: return str(flops) + " Mac" else: if units == "GMac": return str(round(flops / 10.0**9, precision)) + " " + units elif units == "MMac": return str(round(flops / 10.0**6, precision)) + " " + units elif units == "KMac": return str(round(flops / 10.0**3, precision)) + " " + units else: return str(flops) + " Mac" def params_to_string(params_num, units=None, precision=2): if units is None: if params_num // 10**6 > 0: return str(round(params_num / 10**6, 2)) + " M" elif params_num // 10**3: return str(round(params_num / 10**3, 2)) + " k" else: return str(params_num) else: if units == "M": return str(round(params_num / 10.0**6, precision)) + " " + units elif units == "K": return str(round(params_num / 10.0**3, precision)) + " " + units else: return str(params_num) def accumulate_flops(self): if is_supported_instance(self): return self.__flops__ else: sum = 0 for m in self.children(): sum += m.accumulate_flops() return sum def print_model_with_flops(model, total_flops, total_params, units="GMac", precision=3, ost=sys.stdout): def accumulate_params(self): if is_supported_instance(self): return self.__params__ else: sum = 0 for m in self.children(): sum += m.accumulate_params() return sum def flops_repr(self): accumulated_params_num = self.accumulate_params() accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__ return ", ".join( [ params_to_string(accumulated_params_num, units="M", precision=precision), "{:.3%} Params".format(accumulated_params_num / total_params), flops_to_string(accumulated_flops_cost, units=units, precision=precision), "{:.3%} MACs".format(accumulated_flops_cost / total_flops), self.original_extra_repr(), ] ) def add_extra_repr(m): m.accumulate_flops = accumulate_flops.__get__(m) m.accumulate_params = accumulate_params.__get__(m) flops_extra_repr = flops_repr.__get__(m) if m.extra_repr != flops_extra_repr: m.original_extra_repr = m.extra_repr m.extra_repr = flops_extra_repr assert m.extra_repr != m.original_extra_repr def del_extra_repr(m): if hasattr(m, "original_extra_repr"): m.extra_repr = m.original_extra_repr del m.original_extra_repr if hasattr(m, "accumulate_flops"): del m.accumulate_flops model.apply(add_extra_repr) print(repr(model), file=ost) model.apply(del_extra_repr) def get_model_parameters_number(model): params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) return params_num def add_flops_counting_methods(net_main_module): # adding additional methods to the existing module object, # this is done this way so that each function has access to self object net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) net_main_module.reset_flops_count() return net_main_module def compute_average_flops_cost(self): """ A method that will be available after add_flops_counting_methods() is called on a desired net object. Returns current mean flops consumption per image. """ for m in self.modules(): m.accumulate_flops = accumulate_flops.__get__(m) flops_sum = self.accumulate_flops() for m in self.modules(): if hasattr(m, "accumulate_flops"): del m.accumulate_flops params_sum = get_model_parameters_number(self) return flops_sum / self.__batch_counter__, params_sum def start_flops_count(self, **kwargs): """ A method that will be available after add_flops_counting_methods() is called on a desired net object. Activates the computation of mean flops consumption per image. Call it before you run the network. """ add_batch_counter_hook_function(self) seen_types = set() def add_flops_counter_hook_function(module, ost, verbose, ignore_list): if type(module) in ignore_list: seen_types.add(type(module)) if is_supported_instance(module): module.__params__ = 0 elif is_supported_instance(module): if hasattr(module, "__flops_handle__"): return if type(module) in CUSTOM_MODULES_MAPPING: handle = module.register_forward_hook(CUSTOM_MODULES_MAPPING[type(module)]) elif getattr(module, "compute_macs", False): handle = module.register_forward_hook(module.compute_macs) else: handle = module.register_forward_hook(MODULES_MAPPING[type(module)]) module.__flops_handle__ = handle seen_types.add(type(module)) else: if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and not type(module) in seen_types: print("Warning: module " + type(module).__name__ + " is treated as a zero-op.", file=ost) seen_types.add(type(module)) self.apply(partial(add_flops_counter_hook_function, **kwargs)) def stop_flops_count(self): """ A method that will be available after add_flops_counting_methods() is called on a desired net object. Stops computing the mean flops consumption per image. Call whenever you want to pause the computation. """ remove_batch_counter_hook_function(self) self.apply(remove_flops_counter_hook_function) def reset_flops_count(self): """ A method that will be available after add_flops_counting_methods() is called on a desired net object. Resets statistics computed so far. """ add_batch_counter_variables_or_reset(self) self.apply(add_flops_counter_variable_or_reset) # ---- Internal functions def empty_flops_counter_hook(module, input, output): module.__flops__ += 0 def upsample_flops_counter_hook(module, input, output): output_size = output[0] batch_size = output_size.shape[0] output_elements_count = batch_size for val in output_size.shape[1:]: output_elements_count *= val module.__flops__ += int(output_elements_count) def relu_flops_counter_hook(module, input, output): active_elements_count = output.numel() module.__flops__ += int(active_elements_count) def linear_flops_counter_hook(module, input, output): input = input[0] # pytorch checks dimensions, so here we don't care much output_last_dim = output.shape[-1] bias_flops = output_last_dim if module.bias is not None else 0 module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops) def pool_flops_counter_hook(module, input, output): input = input[0] module.__flops__ += int(np.prod(input.shape)) def bn_flops_counter_hook(module, input, output): input = input[0] batch_flops = np.prod(input.shape) if module.affine: batch_flops *= 2 module.__flops__ += int(batch_flops) def conv_flops_counter_hook(conv_module, input, output): # Can have multiple inputs, getting the first one input = input[0] batch_size = input.shape[0] output_dims = list(output.shape[2:]) kernel_dims = list(conv_module.kernel_size) in_channels = conv_module.in_channels out_channels = conv_module.out_channels groups = conv_module.groups filters_per_channel = out_channels // groups conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(np.prod(output_dims)) overall_conv_flops = conv_per_position_flops * active_elements_count bias_flops = 0 if conv_module.bias is not None: bias_flops = out_channels * active_elements_count overall_flops = overall_conv_flops + bias_flops conv_module.__flops__ += int(overall_flops) def batch_counter_hook(module, input, output): batch_size = 1 if len(input) > 0: # Can have multiple inputs, getting the first one input = input[0] batch_size = len(input) else: pass print("Warning! No positional inputs found for a module," " assuming batch size is 1.") module.__batch_counter__ += batch_size def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): # matrix matrix mult ih state and internal state flops += w_ih.shape[0] * w_ih.shape[1] # matrix matrix mult hh state and internal state flops += w_hh.shape[0] * w_hh.shape[1] if isinstance(rnn_module, (nn.RNN, nn.RNNCell)): # add both operations flops += rnn_module.hidden_size elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)): # hadamard of r flops += rnn_module.hidden_size # adding operations from both states flops += rnn_module.hidden_size * 3 # last two hadamard product and add flops += rnn_module.hidden_size * 3 elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)): # adding operations from both states flops += rnn_module.hidden_size * 4 # two hadamard product and add for C state flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size # final hadamard flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size return flops def rnn_flops_counter_hook(rnn_module, input, output): """ Takes into account batch goes at first position, contrary to pytorch common rule (but actually it doesn't matter). IF sigmoid and tanh are made hard, only a comparison FLOPS should be accurate """ flops = 0 # input is a tuple containing a sequence to process and (optionally) hidden state inp = input[0] batch_size = inp.shape[0] seq_length = inp.shape[1] num_layers = rnn_module.num_layers for i in range(num_layers): w_ih = rnn_module.__getattr__("weight_ih_l" + str(i)) w_hh = rnn_module.__getattr__("weight_hh_l" + str(i)) if i == 0: input_size = rnn_module.input_size else: input_size = rnn_module.hidden_size flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size) if rnn_module.bias: b_ih = rnn_module.__getattr__("bias_ih_l" + str(i)) b_hh = rnn_module.__getattr__("bias_hh_l" + str(i)) flops += b_ih.shape[0] + b_hh.shape[0] flops *= batch_size flops *= seq_length if rnn_module.bidirectional: flops *= 2 rnn_module.__flops__ += int(flops) def rnn_cell_flops_counter_hook(rnn_cell_module, input, output): flops = 0 inp = input[0] batch_size = inp.shape[0] w_ih = rnn_cell_module.__getattr__("weight_ih") w_hh = rnn_cell_module.__getattr__("weight_hh") input_size = inp.shape[1] flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size) if rnn_cell_module.bias: b_ih = rnn_cell_module.__getattr__("bias_ih") b_hh = rnn_cell_module.__getattr__("bias_hh") flops += b_ih.shape[0] + b_hh.shape[0] flops *= batch_size rnn_cell_module.__flops__ += int(flops) def add_batch_counter_variables_or_reset(module): module.__batch_counter__ = 0 def add_batch_counter_hook_function(module): if hasattr(module, "__batch_counter_handle__"): return handle = module.register_forward_hook(batch_counter_hook) module.__batch_counter_handle__ = handle def remove_batch_counter_hook_function(module): if hasattr(module, "__batch_counter_handle__"): module.__batch_counter_handle__.remove() del module.__batch_counter_handle__ def add_flops_counter_variable_or_reset(module): if is_supported_instance(module): if hasattr(module, "__flops__") or hasattr(module, "__params__"): print( "Warning: variables __flops__ or __params__ are already " "defined for the module" + type(module).__name__ + " ptflops can affect your code!" ) module.__flops__ = 0 module.__params__ = get_model_parameters_number(module) CUSTOM_MODULES_MAPPING = {} MODULES_MAPPING = { # convolutions nn.Conv1d: conv_flops_counter_hook, nn.Conv2d: conv_flops_counter_hook, nn.Conv3d: conv_flops_counter_hook, Conv2d: conv_flops_counter_hook, ModulatedDeformConv: conv_flops_counter_hook, # activations nn.ReLU: relu_flops_counter_hook, nn.PReLU: relu_flops_counter_hook, nn.ELU: relu_flops_counter_hook, nn.LeakyReLU: relu_flops_counter_hook, nn.ReLU6: relu_flops_counter_hook, # poolings nn.MaxPool1d: pool_flops_counter_hook, nn.AvgPool1d: pool_flops_counter_hook, nn.AvgPool2d: pool_flops_counter_hook, nn.MaxPool2d: pool_flops_counter_hook, nn.MaxPool3d: pool_flops_counter_hook, nn.AvgPool3d: pool_flops_counter_hook, nn.AdaptiveMaxPool1d: pool_flops_counter_hook, nn.AdaptiveAvgPool1d: pool_flops_counter_hook, nn.AdaptiveMaxPool2d: pool_flops_counter_hook, nn.AdaptiveAvgPool2d: pool_flops_counter_hook, nn.AdaptiveMaxPool3d: pool_flops_counter_hook, nn.AdaptiveAvgPool3d: pool_flops_counter_hook, # BNs nn.BatchNorm1d: bn_flops_counter_hook, nn.BatchNorm2d: bn_flops_counter_hook, nn.BatchNorm3d: bn_flops_counter_hook, nn.GroupNorm: bn_flops_counter_hook, # FC nn.Linear: linear_flops_counter_hook, # Upscale nn.Upsample: upsample_flops_counter_hook, # Deconvolution nn.ConvTranspose1d: conv_flops_counter_hook, nn.ConvTranspose2d: conv_flops_counter_hook, nn.ConvTranspose3d: conv_flops_counter_hook, ConvTranspose2d: conv_flops_counter_hook, # RNN nn.RNN: rnn_flops_counter_hook, nn.GRU: rnn_flops_counter_hook, nn.LSTM: rnn_flops_counter_hook, nn.RNNCell: rnn_cell_flops_counter_hook, nn.LSTMCell: rnn_cell_flops_counter_hook, nn.GRUCell: rnn_cell_flops_counter_hook, } def is_supported_instance(module): if ( type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING or getattr(module, "compute_macs", False) ): return True return False def remove_flops_counter_hook_function(module): if is_supported_instance(module): if hasattr(module, "__flops_handle__"): module.__flops_handle__.remove() del module.__flops_handle__