# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from .cls_head import ClsHead


@MODELS.register_module()
class LinearClsHead(ClsHead):
    """Classification head consisting of a single fully connected layer.

    Any keyword argument not listed in the signature is forwarded
    unchanged to the ``ClsHead`` constructor.

    Args:
        num_classes (int): Number of categories excluding the background
            category. Must be greater than zero.
        in_channels (int): Number of channels in the input feature map.
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
        cal_acc (bool): Whether to calculate accuracy during training.
            If you use batch augmentations like Mixup and CutMix during
            training, it is pointless to calculate accuracy.
            Defaults to False.
        init_cfg (dict, optional): The config to control the
            initialization. Defaults to
            ``dict(type='Normal', layer='Linear', std=0.01)``.

    Raises:
        ValueError: If ``num_classes`` is not greater than zero.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 init_cfg: Optional[dict] = dict(
                     type='Normal', layer='Linear', std=0.01),
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)

        self.num_classes = num_classes
        self.in_channels = in_channels

        # Reject an unusable class count before building the layer.
        if num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.fc = nn.Linear(in_channels, num_classes)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """Select the feature that is fed to the classifier.

        ``feats`` is a tuple of tensors, one per backbone stage. This head
        has no extra modules ahead of the classifier, so the last element
        of the tuple is returned as-is.

        Args:
            feats (Tuple[torch.Tensor]): Features of the backbone stages.

        Returns:
            torch.Tensor: The last element of ``feats``.
        """
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """Compute classification scores from backbone features.

        Args:
            feats (Tuple[torch.Tensor]): Features of the backbone stages.

        Returns:
            torch.Tensor: Output of the linear layer applied to the result
            of :meth:`pre_logits`.
        """
        # Pick the classifier input, then apply the final linear layer.
        return self.fc(self.pre_logits(feats))