Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class NoDimReduction(nn.Module):
    """Identity "dimension reduction": features are kept at full width.

    Args:
        full_channels: Feature width.
        reduced_channels: Must equal ``full_channels``; accepted so the
            constructor signature parallels the other dim-reduction modules.

    Raises:
        ValueError: If ``full_channels != reduced_channels``.
    """

    def __init__(self, full_channels, reduced_channels):
        super().__init__()
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if full_channels != reduced_channels:
            raise ValueError(
                "NoDimReduction requires full_channels == reduced_channels, "
                f"got {full_channels} and {reduced_channels}"
            )

    def forward(self, features):
        """Return ``features`` unchanged."""
        return features

    def transform_expand(self, features):
        """Return ``features`` unchanged.

        Provided so this module exposes the same ``transform_expand`` method
        as the other dim-reduction modules. Unlike them, no normalisation is
        applied here.
        """
        return self.forward(features)
class MlpDimReduction(nn.Module):
    """Expand reduced features to full width with a two-layer MLP.

    Args:
        full_channels: Width of the expanded output (last dimension).
        reduced_channels: Width of the incoming reduced features.
        latent_channels: Width of the hidden layer.
    """

    def __init__(self, full_channels, reduced_channels, latent_channels):
        super().__init__()
        # Attribute names become state_dict keys, and creation order fixes the
        # RNG draws used for initialisation -- keep both stable.
        self.linear_in = nn.Linear(reduced_channels, latent_channels)
        self.linear_out = nn.Linear(latent_channels, full_channels)
        self.relu = nn.ReLU()

    def transform_expand(self, features):
        """Map features (last dim = reduced_channels) to L2-normalised
        vectors whose last dim is full_channels."""
        hidden = self.linear_in(features)
        hidden = self.relu(hidden)
        expanded = self.linear_out(hidden)
        # Unit-normalise along the channel (last) axis.
        return F.normalize(expanded, dim=-1)
class OrthogonalLinearDimReduction(nn.Module):
    """Expand reduced features to full width with a single affine map.

    The weight matrix starts as a rectangular identity, so at initialisation
    the reduced channels are copied into the first output channels and the
    remainder are zero. NOTE(review): nothing in this class enforces
    orthogonality of ``weights`` during training; presumably that is handled
    elsewhere (or not at all) -- confirm.

    Args:
        full_channels: Width of the expanded output (last dimension).
        reduced_channels: Width of the incoming reduced features.
    """

    def __init__(self, full_channels, reduced_channels):
        super().__init__()
        # Registration order (bias, then weights) determines parameters()
        # order; keep it unchanged.
        self.bias = torch.nn.Parameter(torch.zeros(full_channels))
        self.weights = torch.nn.Parameter(torch.eye(full_channels, reduced_channels))

    def transform_expand(self, features):
        """Apply ``features @ weights.T + bias`` and L2-normalise the last axis."""
        weight_t = self.weights.t()
        affine = torch.matmul(features, weight_t) + self.bias
        return F.normalize(affine, dim=-1)