# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from src.efficientvit.models.utils import build_kwargs_from_config

__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"]


class LayerNorm2d(nn.LayerNorm):
    """Layer normalization applied along dim 1 of a 4-D tensor.

    The statistics are computed over dim 1 only (the mean and the biased
    variance, i.e. the mean of squared deviations), and the affine
    parameters are broadcast as ``(1, C, 1, 1)``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along dim 1 and apply the optional affine transform."""
        out = x - torch.mean(x, dim=1, keepdim=True)
        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
        if self.elementwise_affine:
            out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        return out


# Registry of normalization layers selectable by name in `build_norm`;
# register new normalization classes here.
REGISTERED_NORM_DICT: dict[str, type] = {
    "bn2d": nn.BatchNorm2d,
    "ln": nn.LayerNorm,
    "ln2d": LayerNorm2d,
}


def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:
    """Instantiate the normalization layer registered under ``name``.

    Args:
        name: Key into ``REGISTERED_NORM_DICT``.
        num_features: Feature count; forwarded as ``normalized_shape`` for
            the layer-norm variants (``"ln"``, ``"ln2d"``) and as
            ``num_features`` otherwise.
        **kwargs: Extra constructor arguments, passed through
            ``build_kwargs_from_config`` before instantiation (presumably to
            keep only the arguments the class accepts -- confirm against
            that helper).

    Returns:
        The constructed module, or ``None`` when ``name`` is not registered.
    """
    if name in ["ln", "ln2d"]:
        kwargs["normalized_shape"] = num_features
    else:
        kwargs["num_features"] = num_features

    if name not in REGISTERED_NORM_DICT:
        return None

    norm_cls = REGISTERED_NORM_DICT[name]
    args = build_kwargs_from_config(kwargs, norm_cls)
    return norm_cls(**args)


def _per_channel_mean(x: torch.Tensor) -> torch.Tensor:
    """Average ``x`` over dims 0, 2 and 3, keeping them as size-1 dims."""
    return x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)


def reset_bn(
    model: nn.Module,
    data_loader: list,
    sync=True,
    progress_bar=False,
) -> None:
    """Re-estimate batch-norm running statistics of ``model`` from data.

    A deep copy of ``model`` is run over ``data_loader`` with every
    ``_BatchNorm`` layer's ``forward`` replaced by a version that normalizes
    with the current batch statistics and accumulates them in running
    averages. The averaged mean and variance are then copied into the
    ``running_mean`` / ``running_var`` buffers of the original ``model``.

    Args:
        model: Network whose batch-norm buffers are updated in place.
        data_loader: Sized iterable of input batches; each batch is moved to
            the model's device and fed to the model directly.
        sync: When true, per-batch statistics are passed through
            ``sync_tensor(..., reduce="cat")`` and averaged over dim 0
            before use.
        progress_bar: Show a tqdm progress bar (only when ``is_master()``).
    """
    import copy

    import torch.nn.functional as F
    from tqdm import tqdm

    # NOTE(review): these paths use the `efficientvit` package root while the
    # module-level import uses `src.efficientvit` -- confirm both resolve in
    # this project layout.
    from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
    from efficientvit.models.utils import get_device, list_join

    bn_mean = {}
    bn_var = {}

    def new_forward(bn, mean_est, var_est):
        # Factory binding one BN layer and its two meters, so each patched
        # forward refers to its own layer rather than the loop variable.
        def lambda_forward(x):
            x = x.contiguous()

            batch_mean = _per_channel_mean(x)  # 1, C, 1, 1
            if sync:
                batch_mean = sync_tensor(batch_mean, reduce="cat")
                batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)

            centered = x - batch_mean
            batch_var = _per_channel_mean(centered * centered)
            if sync:
                batch_var = sync_tensor(batch_var, reduce="cat")
                batch_var = torch.mean(batch_var, dim=0, keepdim=True)

            # flatten() always yields a 1-D (C,) tensor; squeeze() would
            # collapse a single-channel layer to 0-d and break shape[0] below.
            batch_mean = batch_mean.flatten()
            batch_var = batch_var.flatten()

            mean_est.update(batch_mean.data, x.size(0))
            var_est.update(batch_var.data, x.size(0))

            # bn forward using calculated mean & var
            _feature_dim = batch_mean.shape[0]
            # Layers built with affine=False have no weight/bias.
            weight = None if bn.weight is None else bn.weight[:_feature_dim]
            bias = None if bn.bias is None else bn.bias[:_feature_dim]
            return F.batch_norm(
                x,
                batch_mean,
                batch_var,
                weight,
                bias,
                False,
                0.0,
                bn.eps,
            )

        return lambda_forward

    # Patch the copy, never the caller's model.
    tmp_model = copy.deepcopy(model)
    for name, m in tmp_model.named_modules():
        if isinstance(m, _BatchNorm):
            bn_mean[name] = AverageMeter(is_distributed=False)
            bn_var[name] = AverageMeter(is_distributed=False)
            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    # skip if there is no batch normalization layers in the network
    if not bn_mean:
        return

    tmp_model.eval()
    with torch.no_grad(), tqdm(
        total=len(data_loader),
        desc="reset bn",
        disable=not progress_bar or not is_master(),
    ) as t:
        for images in data_loader:
            images = images.to(get_device(tmp_model))
            tmp_model(images)
            t.set_postfix(
                {
                    "bs": images.size(0),
                    "res": list_join(images.shape[-2:], "x"),
                }
            )
            t.update()

    # Module names are identical in the copy and the original, so the
    # collected statistics can be matched back by name.
    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            feature_dim = bn_mean[name].avg.size(0)
            assert isinstance(m, _BatchNorm)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)


def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None:
    """Set ``eps`` on every GroupNorm, LayerNorm and batch-norm layer of ``model``.

    Args:
        model: Network to update in place.
        eps: New epsilon value; when ``None`` the model is left untouched.
    """
    if eps is None:
        return
    for m in model.modules():
        if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
            m.eps = eps