Spaces:
Build error
Build error
| """Layers implementing dimensionality reduction of a feature map""" | |
| import torch | |
| from torch import nn | |
| from ..utils import whitening | |
class ConvDimReduction(nn.Conv2d):
    """Dimensionality reduction as a 1x1 convolutional layer.

    A plain ``nn.Conv2d`` with a 1x1 kernel and a bias. After
    :meth:`initialize_pca_whitening`, the weight ``W`` holds the leading
    ``dim`` PCA-whitening directions and the bias is ``-W m`` (``m`` being the
    learned mean). The layer therefore computes ``W (x - m)`` at every spatial
    position.

    :param int input_dim: Network out_channels (input channels of this layer)
    :param int dim: Whitening out_channels, for dimensionality reduction
    """

    def __init__(self, input_dim, dim):
        super().__init__(input_dim, dim, (1, 1), padding=0, bias=True)

    def initialize_pca_whitening(self, des):
        """Initialize PCA whitening from given descriptors. Return tuple of shift and projection.

        :param des: Descriptors forwarded to ``whitening.pcawhitenlearn_shrinkage``
        :return: The ``(m, P)`` pair exactly as returned by
            ``whitening.pcawhitenlearn_shrinkage``
        :raises ValueError: If the learned whitening does not fit this layer's
            ``in_channels`` / ``out_channels``
        """
        m, P = whitening.pcawhitenlearn_shrinkage(des)
        self._load_pca(m, P)
        return m, P

    def _load_pca(self, m, P):
        """Copy the leading ``out_channels`` PCA directions into weight and bias.

        Uses the layout this layer has always consumed from the whitening helper:
        ``m`` is a 2-D row of shape ``(1, in_channels)``, and ``P`` has shape
        ``(in_channels, n_components)`` with one projection direction per column.
        Both numpy arrays and torch tensors are accepted.

        :raises ValueError: If ``P`` does not match ``in_channels`` or provides
            fewer than ``out_channels`` components
        """
        weight = self.weight
        # Build in the parameter's own dtype/device so a .double()/.cuda() layer keeps them.
        projection = torch.as_tensor(P, dtype=weight.dtype, device=weight.device).T
        mean = torch.as_tensor(m, dtype=weight.dtype, device=weight.device).T
        n_components, in_dim = projection.shape
        if in_dim != self.in_channels:
            raise ValueError(
                f"PCA input dimension {in_dim} does not match in_channels={self.in_channels}"
            )
        if n_components < self.out_channels:
            raise ValueError(
                f"PCA provides {n_components} components, "
                f"but out_channels={self.out_channels}"
            )
        # Only the kept directions are needed; no point projecting the mean onto the rest.
        projection = projection[:self.out_channels]
        # W (x - m) = W x - W m, so centering folds into the bias -W m.
        # reshape(-1) (not squeeze) keeps a 1-D bias even when out_channels == 1.
        shift = -(projection @ mean).reshape(-1)
        with torch.no_grad():
            # In-place copy keeps the Parameter objects, their dtype/device, and checks shapes.
            weight.copy_(projection[:, :, None, None])
            self.bias.copy_(shift)