# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Linear, ShapeSpec


class ZeroShotClassifier(nn.Module):
    '''
    Classification head that scores projected features against a bank of
    per-class weight vectors.

    Input features go through a linear projection to ``zs_weight_dim``
    dimensions and are then multiplied with a ``D x (C + 1)`` weight matrix
    whose last column is all zeros. The matrix is either loaded from a
    ``.npy`` file and kept as a buffer, or drawn at random and kept as a
    parameter when ``zs_weight_path == 'rand'``.
    '''

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        num_classes: int,
        zs_weight_path: str,
        zs_weight_dim: int = 512,
        use_bias: float = 0.0,
        norm_weight: bool = True,
        norm_temperature: float = 50.0,
    ):
        '''
        Args:
            input_shape: shape of the incoming features; a plain int is
                accepted and interpreted as the channel count.
            num_classes: number of classes C; the stored weight matrix must
                end up with C + 1 columns.
            zs_weight_path: path of a ``.npy`` file readable by ``np.load``
                that holds a C x D array, or the string 'rand' to draw the
                weights at random.
            zs_weight_dim: dimension D of the weight vectors and of the
                linear projection output.
            use_bias: a bias term is created only when this value is
                negative; the value itself is the initial bias.
            norm_weight: when True, L2-normalize each column of the weight
                matrix here and normalize/scale the features in ``forward``.
            norm_temperature: factor applied to the normalized features in
                ``forward``.
        '''
        super().__init__()
        # Backward compatibility: a bare int stands for the channel count.
        if isinstance(input_shape, int):
            input_shape = ShapeSpec(channels=input_shape)
        in_features = (
            input_shape.channels
            * (input_shape.width or 1)
            * (input_shape.height or 1)
        )

        self.norm_weight = norm_weight
        self.norm_temperature = norm_temperature

        # Only a negative ``use_bias`` turns the bias on.
        self.use_bias = use_bias < 0
        if self.use_bias:
            self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)

        self.linear = nn.Linear(in_features, zs_weight_dim)

        randomly_initialized = zs_weight_path == 'rand'
        if randomly_initialized:
            class_weights = torch.randn((zs_weight_dim, num_classes))
            nn.init.normal_(class_weights, std=0.01)
        else:
            loaded = torch.tensor(np.load(zs_weight_path), dtype=torch.float32)
            class_weights = loaded.permute(1, 0).contiguous()  # D x C

        # Append one all-zero column: D x (C + 1).
        zero_column = class_weights.new_zeros((zs_weight_dim, 1))
        class_weights = torch.cat([class_weights, zero_column], dim=1)

        if self.norm_weight:
            class_weights = F.normalize(class_weights, p=2, dim=0)

        # Random weights are registered as a parameter, loaded ones as a buffer.
        if randomly_initialized:
            self.zs_weight = nn.Parameter(class_weights)
        else:
            self.register_buffer('zs_weight', class_weights)

        assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape

    @classmethod
    def from_config(cls, cfg, input_shape):
        '''Collect the constructor arguments from ``cfg``.'''
        box_head_cfg = cfg.MODEL.ROI_BOX_HEAD
        kwargs = {
            'input_shape': input_shape,
            'num_classes': cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            'zs_weight_path': box_head_cfg.ZEROSHOT_WEIGHT_PATH,
            'zs_weight_dim': box_head_cfg.ZEROSHOT_WEIGHT_DIM,
            'use_bias': box_head_cfg.USE_BIAS,
            'norm_weight': box_head_cfg.NORM_WEIGHT,
            'norm_temperature': box_head_cfg.NORM_TEMP,
        }
        return kwargs

    def forward(self, x, classifier=None):
        '''
        Inputs:
            x: B x D'
            classifier: optional C' x D tensor of class weights that is used
                instead of the stored ``zs_weight`` for this call.

        Returns:
            The B x C' scores (B x (C + 1) with the stored weights).
        '''
        x = self.linear(x)

        if classifier is None:
            class_weights = self.zs_weight
        else:
            class_weights = classifier.permute(1, 0).contiguous()  # D x C'
            if self.norm_weight:
                class_weights = F.normalize(class_weights, p=2, dim=0)

        if self.norm_weight:
            x = self.norm_temperature * F.normalize(x, p=2, dim=1)

        scores = torch.mm(x, class_weights)
        if self.use_bias:
            scores = scores + self.cls_bias
        return scores