|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
|
|
import numpy as np |
|
|
|
|
|
class FeatureMixerLayer(nn.Module):
    """Residual MLP block acting on the last dimension of its input.

    Computes ``x + Linear(ReLU(Linear(LayerNorm(x))))`` where the hidden width
    is ``int(in_dim * mlp_ratio)``. In MixVPR this block is instantiated with
    ``in_dim = in_h * in_w``, i.e. it mixes the flattened spatial positions.

    Args:
        in_dim: size of the last dimension of the input (and of the output).
        mlp_ratio: multiplier for the hidden layer width.
    """

    def __init__(self, in_dim, mlp_ratio=1):
        super().__init__()
        hidden = int(in_dim * mlp_ratio)

        # Submodule order (norm, fc, act, fc) fixes the state_dict keys
        # mix.0 / mix.1 / mix.3, so it must not change.
        self.mix = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_dim),
        )

        # Re-initialise every fully-connected layer: truncated-normal weights
        # (std 0.02) and zero biases. Only ``self.mix`` holds Linear layers.
        linears = [layer for layer in self.mix if isinstance(layer, nn.Linear)]
        for linear in linears:
            nn.init.trunc_normal_(linear.weight, std=0.02)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return ``x`` plus the MLP's output (skip connection)."""
        residual = self.mix(x)
        return x + residual
|
|
|
|
|
class MixVPR(nn.Module):
    """MixVPR aggregation head producing an L2-normalised global descriptor.

    Pipeline (see ``forward``):
      1. flatten the spatial dims of the input,
      2. apply ``mix_depth`` stacked ``FeatureMixerLayer`` blocks over the
         flattened spatial axis (size ``in_h * in_w``),
      3. project the channel axis ``in_channels -> out_channels``,
      4. project the spatial axis ``in_h * in_w -> out_rows``,
      5. flatten to ``(batch, out_channels * out_rows)`` and L2-normalise
         each row.

    Args:
        in_channels: channel depth of the incoming feature map.
        in_h: height of the incoming feature map.
        in_w: width of the incoming feature map.
        out_channels: width of the channel-wise projection.
        mix_depth: number of stacked FeatureMixerLayer blocks.
        mlp_ratio: hidden-width multiplier inside each mixer block.
        out_rows: width of the row-wise (spatial) projection.
    """

    def __init__(self,
                 in_channels=1024,
                 in_h=20,
                 in_w=20,
                 out_channels=512,
                 mix_depth=1,
                 mlp_ratio=1,
                 out_rows=4,
                 ) -> None:
        super().__init__()

        # Keep the configuration around as plain attributes.
        self.in_h, self.in_w = in_h, in_w
        self.in_channels = in_channels
        self.out_channels, self.out_rows = out_channels, out_rows
        self.mix_depth, self.mlp_ratio = mix_depth, mlp_ratio

        spatial = in_h * in_w
        # Registration order (mix, channel_proj, row_proj) determines the
        # parameter order and state_dict layout; keep it stable.
        mixers = [
            FeatureMixerLayer(in_dim=spatial, mlp_ratio=mlp_ratio)
            for _ in range(mix_depth)
        ]
        self.mix = nn.Sequential(*mixers)
        self.channel_proj = nn.Linear(in_channels, out_channels)
        self.row_proj = nn.Linear(spatial, out_rows)

    def forward(self, x):
        """Aggregate ``x`` into a descriptor of shape (B, out_channels * out_rows).

        Expects ``x`` shaped (B, in_channels, in_h, in_w), as built in
        ``main()``; ``flatten(2)`` turns it into (B, in_channels, in_h * in_w).
        """
        tokens = x.flatten(2)
        tokens = self.mix(tokens)
        # Swap the last two axes so the Linear layer sees channels last.
        channel_mixed = self.channel_proj(tokens.transpose(1, 2))
        # Swap back so the spatial axis is last for the row projection.
        row_mixed = self.row_proj(channel_mixed.transpose(1, 2))
        return F.normalize(row_mixed.flatten(1), p=2, dim=1)
|
|
|
|
|
|
|
|
|
def print_nb_params(m):
    """Print the number of trainable parameters of ``m`` in millions.

    Only parameters with ``requires_grad=True`` are counted. The value is
    printed as ``Trainable parameters: <count/1e6>M`` using the ``.3``
    general format (3 significant digits), so very small models print in
    scientific notation, e.g. ``8e-06M``.

    Args:
        m: the ``nn.Module`` to inspect.

    Returns:
        The trainable-parameter count as an ``int``, so callers can use the
        number directly instead of parsing the printed text.
    """
    # Tensor.numel() is the torch-native element count; a generator avoids
    # materialising an intermediate list just to sum it.
    params = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f'Trainable parameters: {params/1e6:.3}M')
    return params
|
|
|
|
|
def main():
    """Smoke-test MixVPR on one random 1024x20x20 feature map.

    Prints the trainable-parameter count of the aggregator and the shape of
    the descriptor it produces.
    """
    # Random input is drawn before the model is built (same RNG order as before).
    feature_map = torch.randn(1, 1024, 20, 20)
    aggregator = MixVPR(
        in_channels=1024,
        in_h=20,
        in_w=20,
        out_channels=1024,
        mix_depth=4,
        mlp_ratio=1,
        out_rows=4,
    )

    print_nb_params(aggregator)
    descriptor = aggregator(feature_map)
    print(descriptor.shape)
|
|
|
|
|
# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|