# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.fileio import list_from_file
from mmengine.runner import autocast
from mmengine.utils import is_seq_of

from mmpretrain.models.losses import convert_to_one_hot
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from .cls_head import ClsHead


class NormProduct(nn.Linear):
    """An enhanced linear layer with k clustering centers to calculate product
    between normalized input and linear weight.

    The underlying weight has ``out_features * k`` rows; rows
    ``c * k`` .. ``c * k + k - 1`` are the ``k`` sub-centers of output ``c``.
    When ``k > 1`` the output for each class is the maximum product over its
    sub-centers.

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        k (int): The number of clustering centers. Defaults to 1.
        bias (bool): Whether there is bias. If set to ``False``, the
            layer will not learn an additive bias. Defaults to ``False``.
        feature_norm (bool): Whether to normalize the input feature.
            Defaults to ``True``.
        weight_norm (bool): Whether to normalize the weight.
            Defaults to ``True``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 k: int = 1,
                 bias: bool = False,
                 feature_norm: bool = True,
                 weight_norm: bool = True):

        super().__init__(in_features, out_features * k, bias=bias)
        self.weight_norm = weight_norm
        self.feature_norm = feature_norm
        # ``nn.Linear`` stored ``out_features * k``; keep the per-class count
        # so ``forward`` can regroup the sub-centers.
        self.out_features = out_features
        self.k = k

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the (optionally normalized) products with the weight,
        max-pooled over the ``k`` sub-centers of every class."""
        if self.feature_norm:
            input = F.normalize(input)
        if self.weight_norm:
            weight = F.normalize(self.weight)
        else:
            weight = self.weight
        cosine_all = F.linear(input, weight, self.bias)

        if self.k == 1:
            return cosine_all
        else:
            cosine_all = cosine_all.view(-1, self.out_features, self.k)
            cosine, _ = torch.max(cosine_all, dim=2)
            return cosine


@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
    """ArcFace classifier head.

    A PyTorch implementation of the papers "ArcFace: Additive Angular Margin
    Loss for Deep Face Recognition" and "Sub-center ArcFace: Boosting Face
    Recognition by Large-Scale Noisy Web Faces".

    Example:
        To use ArcFace in config files.

        1. use vanilla ArcFace

        .. code:: python

            model = dict(
                backbone = xxx,
                neck = xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        2. use SubCenterArcFace with 3 sub-centers

        .. code:: python

            model = dict(
                backbone = xxx,
                neck = xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        3. use SubCenterArcFace With CountPowerAdaptiveMargins

        .. code:: python

            model = dict(
                backbone = xxx,
                neck = xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

            custom_hooks = [dict(type='SetAdaptiveMarginsHook')]

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_subcenters (int): Number of subcenters. Defaults to 1.
        scale (float): Scale factor of output logit. Defaults to 64.0.
        margins (float): The penalty margin. Could be the following formats:

            - float (or int): The margin, would be same for all the
              categories.
            - Sequence[float]: The category-based margins list.
            - str: A '.txt' file path which contains a list. Each line
              represents the margin of a category, and the number in the
              i-th row indicates the margin of the i-th class.

            Defaults to 0.5.
        easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_subcenters: int = 1,
                 scale: float = 64.,
                 margins: Optional[Union[float, Sequence[float], str]] = 0.50,
                 easy_margin: bool = False,
                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 init_cfg: Optional[dict] = None):

        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss

        assert num_subcenters >= 1 and num_classes >= 0
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_subcenters = num_subcenters
        self.scale = scale
        self.easy_margin = easy_margin

        self.norm_product = NormProduct(in_channels, num_classes,
                                        num_subcenters)

        # Accept int scalars too; previously e.g. ``margins=0`` fell through
        # to ``list(0)`` and raised a TypeError.
        if isinstance(margins, (int, float)):
            margins = [float(margins)] * num_classes
        elif isinstance(margins, str) and margins.endswith('.txt'):
            margins = [float(item) for item in list_from_file(margins)]
        else:
            assert is_seq_of(list(margins), (float, int)), (
                'the attribute `margins` in ``ArcFaceClsHead`` should be a '
                ' float, a Sequence of float, or a ".txt" file path.')

        assert len(margins) == num_classes, \
            'The length of margins must be equal with num_classes.'

        self.register_buffer(
            'margins', torch.tensor(margins).float(), persistent=False)
        sinm_m, threshold = self._margin_terms(self.margins)
        self.register_buffer('sinm_m', sinm_m, persistent=False)
        self.register_buffer('threshold', threshold, persistent=False)

    @staticmethod
    def _margin_terms(
            margins: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the fallback terms derived from the per-class margins.

        To make ``phi`` monotonic decreasing (refers to
        https://github.com/deepinsight/insightface/issues/108), cosines at or
        below ``threshold = cos(pi - m)`` are replaced by
        ``cosine - sinm_m`` with ``sinm_m = sin(pi - m) * m``.

        Returns:
            Tuple[Tensor, Tensor]: ``(sinm_m, threshold)``, same shape as
            ``margins``.
        """
        sinm_m = torch.sin(math.pi - margins) * margins
        threshold = torch.cos(math.pi - margins)
        return sinm_m, threshold

    def set_margins(self, margins: Union[Sequence[float], float]) -> None:
        """Set margins of the ArcFace head and refresh the derived terms.

        Args:
            margins (Union[Sequence[float], float]): The margins. A scalar
                (float or int) is used for every category; a sequence must
                contain exactly ``num_classes`` numbers.
        """
        if isinstance(margins, (int, float)):
            margins = [float(margins)] * self.num_classes
        margins = list(margins)
        assert is_seq_of(margins, (float, int)) and (
            len(margins) == self.num_classes), (
                f'margins must be Sequence[Union(float, int)], get {margins}')

        self.margins = torch.tensor(
            margins, device=self.margins.device, dtype=torch.float32)
        # Same formulas as used at construction time.
        self.sinm_m, self.threshold = self._margin_terms(self.margins)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``ArcFaceHead``, we just obtain the
        feature of the last stage.
        """
        # The ArcFaceHead doesn't have other module, just return after
        # unpacking.
        return feats[-1]

    def _get_logit_with_margin(self, pre_logits, target):
        """Add the arc margin to the cosine at the target index.

        The target must be in index format.
        """
        assert target.dim() == 1 or (
            target.dim() == 2 and target.shape[1] == 1), \
            'The target must be in index format.'
        cosine = self.norm_product(pre_logits)

        # The product of two L2-normalized vectors can drift marginally
        # outside [-1, 1] through rounding, and ``acos`` then returns NaN.
        # That NaN would survive ``target * phi`` even for non-target entries
        # (0 * NaN is NaN). Clamping strictly inside the domain also keeps
        # the acos gradient finite.
        eps = torch.finfo(cosine.dtype).eps
        theta = torch.acos(cosine.clamp(-1 + eps, 1 - eps))
        phi = torch.cos(theta + self.margins)

        if self.easy_margin:
            # when cosine>0, choose phi
            # when cosine<=0, choose cosine
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # when cos>th, choose phi
            # when cos<=th, choose cosine-mm
            phi = torch.where(cosine > self.threshold, phi,
                              cosine - self.sinm_m)

        target = convert_to_one_hot(target, self.num_classes)
        output = target * phi + (1 - target) * cosine
        return output

    def forward(self,
                feats: Tuple[torch.Tensor],
                target: Optional[torch.Tensor] = None) -> torch.Tensor:
        """The forward process.

        Returns ``scale * cosine`` when ``target`` is None, otherwise
        ``scale`` times the margin-adjusted cosine.
        """
        # Disable AMP
        with autocast(enabled=False):
            # Features produced under AMP may be half precision while the
            # classifier weight stays float32. With autocast disabled nothing
            # casts them, and ``F.linear`` requires matching dtypes.
            pre_logits = self.pre_logits(feats).to(
                self.norm_product.weight.dtype)

            if target is None:
                # when eval, logit is the cosine between W and pre_logits;
                # cos(theta_yj) = (x/||x||) * (W/||W||)
                logit = self.norm_product(pre_logits)
            else:
                # when training, add a margin to the pre_logits where target
                # is True, then logit is the cosine between W and new
                # pre_logits
                logit = self._get_logit_with_margin(pre_logits, target)

        return self.scale * logit

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[DataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of the last item should
                be ``(num_samples, in_channels)``.
            data_samples (List[DataSample]): The annotation data of
                every samples.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # Unpack data samples and pack targets
        label_target = torch.cat([i.gt_label for i in data_samples])
        if 'gt_score' in data_samples[0]:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_score for i in data_samples])
        else:
            target = label_target

        # the index format target would be used
        cls_score = self(feats, label_target)

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses