def set_bn_eval(m):
    """Switch ``m`` to evaluation mode if it is a batch-norm layer.

    Intended to be passed to ``nn.Module.apply``. A module is treated as a
    batch-norm layer when its class name contains ``'BatchNorm'``.

    Args:
        m (nn.Module): The module to inspect.
    """
    if 'BatchNorm' in m.__class__.__name__:
        m.eval()


def set_bn_non_trainable(m):
    """Disable gradients for the affine parameters of a batch-norm layer.

    Intended to be passed to ``nn.Module.apply``. A module is treated as a
    batch-norm layer when its class name contains ``'BatchNorm'``.

    Args:
        m (nn.Module): The module to inspect.
    """
    if 'BatchNorm' not in m.__class__.__name__:
        return
    for param_name in ('weight', 'bias'):
        param = getattr(m, param_name, None)
        # Layers created with ``affine=False`` keep ``weight``/``bias`` as
        # ``None``; there is nothing to freeze on them.
        if param is not None:
            param.requires_grad = False


def freeze_bn_statistics(model):
    """Freeze the running mean and variance of every batch-norm layer.

    Every batch-norm submodule of ``model`` is put into evaluation mode.
    Note that a later call to ``model.train()`` switches those layers back
    to training mode, so this has to be re-applied afterwards.

    Args:
        model (nn.Module): The model whose batch-norm statistics are frozen.
    """
    model.apply(set_bn_eval)


def freeze_bn_parameters(model):
    """Freeze the learnable scale and shift of every batch-norm layer.

    ``requires_grad`` is set to ``False`` on the ``weight`` and ``bias`` of
    every batch-norm submodule of ``model``. Layers without affine
    parameters are left untouched.

    Args:
        model (nn.Module): The model whose batch-norm parameters are frozen.
    """
    model.apply(set_bn_non_trainable)