Realcat
fix: eloftr
63f3cf2
raw
history blame
4.11 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> layers
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 14:46
=================================================='''
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from einops import rearrange
def MLP(channels: list, do_bn=True, ac_fn='relu', norm_fn='bn'):
    """Build a point-wise multi-layer perceptron out of 1x1 ``Conv1d`` layers.

    Consecutive entries of ``channels`` give the input/output width of each
    convolution.  Every convolution except the last one is followed by an
    optional normalisation layer and an optional activation.

    :param channels: layer widths, e.g. ``[in, hidden, ..., out]``
    :param do_bn: accepted for interface compatibility; not consulted here
    :param ac_fn: ``'relu'``, ``'gelu'`` or ``'lrelu'``; any other value adds
        no activation
    :param norm_fn: ``'in'`` (InstanceNorm1d) or ``'bn'`` (BatchNorm1d); any
        other value adds no normalisation
    :return: ``nn.Sequential`` holding the layers in application order
    """
    norm_builders = {
        'in': lambda width: nn.InstanceNorm1d(width, eps=1e-3),
        'bn': lambda width: nn.BatchNorm1d(width, eps=1e-3),
    }
    act_builders = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'lrelu': lambda: nn.LeakyReLU(negative_slope=0.1),
    }
    make_norm = norm_builders.get(norm_fn)
    make_act = act_builders.get(ac_fn)

    last = len(channels) - 1
    modules = []
    for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
        modules.append(nn.Conv1d(c_in, c_out, kernel_size=1, bias=True))
        if idx == last:
            # The output convolution is left bare: no norm, no activation.
            continue
        if make_norm is not None:
            modules.append(make_norm(c_out))
        if make_act is not None:
            modules.append(make_act())
    return nn.Sequential(*modules)
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention over 1-D feature maps.

    Query, key and value are each projected by their own 1x1 convolution,
    split into ``num_heads`` heads, combined with softmax attention and
    merged by a final 1x1 convolution.  The attention weights of the most
    recent forward pass are kept in ``self.prob``.
    """

    def __init__(self, num_heads: int, d_model: int):
        """
        :param num_heads: number of attention heads; must divide ``d_model``
        :param d_model: channel width of the inputs and of the output
        """
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        # The three input projections start out as copies of ``merge`` (same
        # initial weights) but are independent parameters afterwards.
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, M=None):
        '''
        :param query: [B, D, N]
        :param key: [B, D, S]
        :param value: [B, D, S]
        :param M: optional [B, N, S] mask. Entries equal to 1 (or ``True``)
            may be attended to; every other entry is blocked. Numeric and
            boolean masks are both accepted.
        :return: [B, D, N] attended features
        '''
        batch_dim = query.size(0)
        query, key, value = [
            layer(x).view(batch_dim, self.dim, self.num_heads, -1)
            for layer, x in zip(self.proj, (query, key, value))
        ]  # each [B, dim, H, L]

        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / self.dim ** .5
        if M is not None:
            # Compare instead of computing ``1 - M`` so boolean masks work
            # too, and keep a singleton head axis: ``masked_fill`` broadcasts
            # it, so no per-head copy of the mask is materialised.
            blocked = M[:, None, :, :] != 1  # [B, 1, N, S]
            # A large finite negative (rather than -inf) keeps softmax finite
            # even when a whole row is blocked.
            scores = scores.masked_fill(blocked, -torch.finfo(scores.dtype).max)
        prob = F.softmax(scores, dim=-1)  # [B, H, N, S]

        x = torch.einsum('bhnm,bdhm->bdhn', prob, value)
        self.prob = prob
        return self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
class AttentionalPropagation(nn.Module):
    """Attention-based message passing followed by an MLP update.

    A message is gathered from ``source`` with multi-head attention,
    concatenated with ``x`` along the channel axis and passed through an MLP
    that maps ``2 * feature_dim`` channels back to ``feature_dim``.
    """

    def __init__(self, feature_dim: int, num_heads: int, ac_fn='relu', norm_fn='bn'):
        """
        :param feature_dim: channel width of the features being updated
        :param num_heads: number of attention heads
        :param ac_fn: activation name forwarded to ``MLP``
        :param norm_fn: normalisation name forwarded to ``MLP``
        """
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        widths = [feature_dim * 2, feature_dim * 2, feature_dim]
        self.mlp = MLP(widths, ac_fn=ac_fn, norm_fn=norm_fn)
        # Zero the bias of the output convolution.
        output_conv = self.mlp[-1]
        nn.init.constant_(output_conv.bias, 0.0)

    def forward(self, x, source, M=None):
        """
        :param x: features to update; used as the attention query
        :param source: features attended to; used as both key and value
        :param M: optional attention mask handed to the attention module
        :return: output of the MLP applied to ``x`` concatenated with the message
        """
        message = self.attn(x, source, source, M=M)
        # Expose the attention weights of this pass on the module.
        self.prob = self.attn.prob
        fused = torch.cat([x, message], dim=1)
        return self.mlp(fused)
class KeypointEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""

    def __init__(self, input_dim, feature_dim, layers, ac_fn='relu', norm_fn='bn'):
        """
        :param input_dim: number of input channels; 2 means keypoint
            coordinates only, any other value means coordinates plus score
        :param feature_dim: number of output channels
        :param layers: hidden layer widths (any sequence of ints)
        :param ac_fn: activation name forwarded to ``MLP``
        :param norm_fn: normalisation name forwarded to ``MLP``
        """
        super().__init__()
        self.input_dim = input_dim
        # Unpacking (instead of list concatenation) lets ``layers`` be a tuple
        # or any other sequence, not only a list.
        self.encoder = MLP([input_dim, *layers, feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
        # Zero the bias of the output convolution.
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts, scores=None):
        """
        :param kpts: keypoint coordinates, transposed to [B, 2, N] before
            encoding (i.e. expected as [B, N, 2])
        :param scores: per-keypoint scores, expanded to [B, 1, N]; required
            unless ``input_dim == 2``
        :return: encoder output for the assembled input
        :raises ValueError: if ``scores`` is missing while ``input_dim != 2``
        """
        coords = kpts.transpose(1, 2)
        if self.input_dim == 2:
            return self.encoder(coords)
        if scores is None:
            raise ValueError(
                'scores must be provided when input_dim != 2 '
                '(got input_dim={})'.format(self.input_dim))
        inputs = [coords, scores.unsqueeze(1)]  # [B, 2, N] + [B, 1, N]
        return self.encoder(torch.cat(inputs, dim=1))