SuperFeatures / how /layers /dim_reduction.py
YannisK's picture
temp state
32408ed
raw
history blame
1.05 kB
"""Layers implementing dimensionality reduction of a feature map"""
import torch
from torch import nn
from ..utils import whitening
class ConvDimReduction(nn.Conv2d):
    """Dimensionality reduction as a 1x1 convolutional layer.

    Once initialized from a learned PCA whitening ``(m, P)``, the layer maps
    every spatial descriptor ``x`` of its input to ``P[:dim] @ (x - m)``.

    :param int input_dim: Number of channels of the incoming feature map (the network's out_channels)
    :param int dim: Number of output channels, i.e. the reduced dimensionality after whitening
    """

    def __init__(self, input_dim, dim):
        super().__init__(input_dim, dim, (1, 1), padding=0, bias=True)

    def initialize_pca_whitening(self, des):
        """Initialize PCA whitening from given descriptors. Return tuple of shift and projection.

        :param des: descriptors, passed unchanged to ``whitening.pcawhitenlearn_shrinkage``
        :return: the ``(m, P)`` pair exactly as returned by ``pcawhitenlearn_shrinkage``
        :raises ValueError: if the learned whitening does not fit this layer's shape
        """
        m, P = whitening.pcawhitenlearn_shrinkage(des)
        # Rows of ``P.T`` become the output channels of the 1x1 convolution.
        self._load_pca_whitening(m, P.T)
        return m, P

    def _load_pca_whitening(self, mean, projection):
        """Copy a learned whitening into the convolution weight and bias.

        After loading, the layer computes ``projection[:dim] @ (x - mean)`` for
        every spatial location ``x`` of its input.

        :param mean: descriptor mean with ``input_dim`` elements (any shape, e.g. ``(1, D)``)
        :param projection: array/tensor of shape ``(K, input_dim)`` with ``K >= dim``;
            only its first ``dim`` rows are used
        :raises ValueError: if the shapes do not match this layer
        """
        out_dim, in_dim = self.weight.shape[:2]
        # Arithmetic is done in float32 (as before); ``copy_`` below casts the
        # result to the parameter's own dtype and device.
        proj = torch.as_tensor(projection, dtype=torch.float32)
        # Flatten so (1, D), (D, 1) and (D,) means all work, including D == 1.
        mean = torch.as_tensor(mean, dtype=torch.float32).reshape(-1)
        if proj.dim() != 2 or proj.shape[1] != in_dim or proj.shape[0] < out_dim:
            raise ValueError(
                f"projection of shape {tuple(proj.shape)} cannot initialize a "
                f"{in_dim}->{out_dim} reduction layer"
            )
        if mean.numel() != in_dim:
            raise ValueError(
                f"mean has {mean.numel()} elements, expected {in_dim}"
            )
        # Truncate first: only the kept directions are needed for the bias.
        proj = proj[:out_dim]
        # Fold the centering into the bias: W x + b == W (x - mean).
        shift = -(proj @ mean)
        # In-place copies keep the same Parameter objects (and their shapes).
        with torch.no_grad():
            self.weight.copy_(proj.reshape(out_dim, in_dim, 1, 1))
            self.bias.copy_(shift)