| | |
| | |
| | |
| | |
| | |
| |
|
| | from operator import attrgetter |
| | from typing import List, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Run conv followed by eval-mode batch norm as one convolution.

    Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning".
    It relies on the associative law between convolution and the affine
    transform of batch norm, i.e.
    normalize(weight conv feature) = (normalize weight) conv feature,
    so the BN statistics are folded into the conv weight and bias on every
    call instead of being applied to the conv output. It works for the eval
    mode of ConvBN blocks during validation, and can be used for training as
    well. It reduces memory and computation cost.

    Args:
        bn (_BatchNorm): a BatchNorm module; its ``running_mean``,
            ``running_var`` and ``eps`` are read here.
        conv (nn._ConvNd): a conv module providing ``weight``, ``bias`` and
            ``_conv_forward``.
        x (torch.Tensor): Input feature map.

    Returns:
        Whatever ``conv._conv_forward`` returns for the folded weight/bias.
    """
    running_var = bn.running_var

    # Substitute neutral values for the optional parameters so that conv
    # without bias and BN without affine parameters share the same formula.
    conv_bias = (conv.bias if conv.bias is not None else
                 torch.zeros_like(running_var))
    gamma = bn.weight if bn.weight is not None else torch.ones_like(
        running_var)
    beta = bn.bias if bn.bias is not None else torch.zeros_like(running_var)

    # 1 / sqrt(var + eps), shaped [-1, 1, ..., 1] so that it broadcasts along
    # the first dimension of the conv weight.
    broadcast_shape = [-1] + [1] * (conv.weight.dim() - 1)
    inv_std = torch.rsqrt(running_var + bn.eps).reshape(broadcast_shape)
    scale = gamma.view_as(inv_std) * inv_std

    # Fold the per-channel scale into the weight, and scale + shift into the
    # bias.
    fused_weight = conv.weight * scale
    fused_bias = beta + scale.flatten() * (conv_bias - bn.running_mean)

    return conv._conv_forward(x, fused_weight, fused_bias)
| |
|
| |
|
def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Decide whether to use `efficient_conv_bn_eval_forward`.

    The fused `efficient_conv_bn_eval_forward` is used only when `bn` is in
    `eval` mode AND has running statistics to fold into the conv weights.
    Otherwise the ordinary two-step path (conv, then bn) is executed.

    Args:
        bn (_BatchNorm): the BatchNorm module that follows `conv`.
        conv (nn._ConvNd): the conv module.
        x (torch.Tensor): Input feature map.

    Returns:
        The output of the conv + bn pair for input `x`.
    """
    # A BatchNorm created with ``track_running_stats=False`` has no running
    # mean/var and normalizes with batch statistics even in eval mode, so
    # there is nothing to fold into the conv; let `bn` handle that case.
    has_running_stats = (
        bn.running_mean is not None and bn.running_var is not None)

    if not bn.training and has_running_stats:
        return efficient_conv_bn_eval_forward(bn, conv, x)

    conv_out = conv._conv_forward(x, conv.weight, conv.bias)
    return bn(conv_out)
| |
|
| |
|
def efficient_conv_bn_eval_graph_transform(fx_model):
    """Find consecutive conv+bn calls in the graph, inplace modify the graph
    with the fused operation.

    A (conv, bn) pair is rewritten only when all of the following hold:

    * the bn node is a ``call_module`` node whose first positional argument
      is itself a ``call_module`` node;
    * that source module is a ``_ConvNd`` but not a transposed convolution;
    * the conv output is consumed by the bn node only.

    Every other node is left untouched.

    Args:
        fx_model: a traced graph module; ``named_modules()``, ``graph`` and
            ``recompile()`` are used here. It is modified in place.
    """
    modules = dict(fx_model.named_modules())
    graph = fx_model.graph

    conv_class = torch.nn.modules.conv._ConvNd
    bn_class = torch.nn.modules.batchnorm._BatchNorm
    # Transposed convolutions store output channels on dim 1 of the weight,
    # while the fused forward scales the weight along dim 0, so they cannot
    # be fused by this transform.
    transposed_conv_class = torch.nn.modules.conv._ConvTransposeNd

    pairs = []
    # Collect first, rewrite later: the graph must not be mutated while its
    # node list is being iterated.
    for node in graph.nodes:
        # Only module calls can be the bn half of a pair.
        if node.op != 'call_module':
            continue
        if not isinstance(modules.get(node.target), bn_class):
            continue

        # The bn input must be a positional argument produced by another
        # module call; placeholders, function calls, constants or
        # keyword-only calls have no module behind them to look up.
        if not node.args:
            continue
        source = node.args[0]
        if getattr(source, 'op', None) != 'call_module':
            continue

        source_module = modules.get(source.target)
        if not isinstance(source_module, conv_class):
            continue
        if isinstance(source_module, transposed_conv_class):
            continue

        # The conv output must feed the bn only, and the conv input must be
        # positional so that it can be forwarded to the fused call.
        if len(source.users) > 1 or not source.args:
            continue

        pairs.append((source, node))

    for conv_node, bn_node in pairs:
        # Use the context manager so the graph's insert point is restored
        # afterwards instead of being left on a node that is erased below.
        with graph.inserting_before(conv_node):
            # Fetch the module objects themselves so the fused function can
            # read their parameters and training flag at call time.
            conv_get_node = graph.create_node(
                op='get_attr', target=conv_node.target, name='get_conv')
            bn_get_node = graph.create_node(
                op='get_attr', target=bn_node.target, name='get_bn')
            new_node = graph.create_node(
                op='call_function',
                target=efficient_conv_bn_eval_control,
                args=(bn_get_node, conv_get_node, conv_node.args[0]),
                name='efficient_conv_bn_eval')

        # Redirect every consumer of the bn output to the fused node, then
        # drop the now-unused nodes (bn first, since it uses the conv node).
        bn_node.replace_all_uses_with(new_node)
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)

    # Validate the rewritten graph and regenerate the forward code.
    graph.lint()
    fx_model.recompile()
| |
|
| |
|
def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    """Enable the efficient conv+bn eval path on one module.

    The module is symbolically traced with ``torch.fx``, the traced graph is
    rewritten by `efficient_conv_bn_eval_graph_transform`, and the traced
    module's ``forward`` is then assigned to ``model.forward``.

    Args:
        model (torch.nn.Module): the module to patch; its ``forward``
            attribute is overwritten.
    """
    from torch.fx import GraphModule, symbolic_trace

    traced: GraphModule = symbolic_trace(model)
    efficient_conv_bn_eval_graph_transform(traced)
    # Only ``forward`` is swapped; `model` itself stays the same object.
    model.forward = traced.forward
| |
|
| |
|
def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
                                   modules: Union[List[str], str]):
    """Enable the efficient conv+bn eval path on selected submodules.

    Args:
        model (torch.nn.Module): the root module.
        modules (Union[List[str], str]): one attribute path or a list of
            attribute paths (resolved with ``operator.attrgetter``, so dotted
            names are followed) naming the submodules of `model` to patch.
    """
    # A single name is treated as a one-element list.
    names = [modules] if isinstance(modules, str) else modules
    for dotted_path in names:
        submodule = attrgetter(dotted_path)(model)
        turn_on_efficient_conv_bn_eval_for_single_model(submodule)
| |
|