# Vincentqyw
# fix: roma
# c74a070
from loguru import logger
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .linear_attention import LinearAttention, FullAttention
class LoFTREncoderLayer(nn.Module):
    """Single attention + feed-forward encoder layer with a residual output.

    The query sequence ``x`` attends to ``source`` through multi-head
    attention; the attended message is fused with ``x`` by a two-layer MLP
    and added back onto ``x``.
    """

    def __init__(self, d_model, nhead, attention="linear"):
        """
        Args:
            d_model (int): channel width of the input features
            nhead (int): number of attention heads (per-head width is
                ``d_model // nhead``)
            attention (str): ``"linear"`` selects ``LinearAttention``; any
                other value selects ``FullAttention``
        """
        super().__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # bias-free projections for queries, keys and values
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        if attention == "linear":
            self.attention = LinearAttention()
        else:
            self.attention = FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network; its input is [x, message] concatenated on channels
        hidden = d_model * 2
        self.mlp = nn.Sequential(
            nn.Linear(hidden, hidden, bias=False),
            nn.GELU(),
            nn.Linear(hidden, d_model, bias=False),
        )

        # layer norms applied after attention and after the MLP
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        Returns:
            torch.Tensor: ``x`` plus the normalised feed-forward message, [N, L, C]
        """
        batch = x.shape[0]
        heads, head_dim = self.nhead, self.dim

        def split_heads(tensor):
            # [N, T, C] -> [N, T, H, D]
            return tensor.view(batch, -1, heads, head_dim)

        # multi-head attention: queries come from x, keys/values from source
        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(source))
        v = split_heads(self.v_proj(source))
        attended = self.attention(q, k, v, q_mask=x_mask, kv_mask=source_mask)

        # merge the heads back into the channel dimension: [N, L, C]
        attended = self.merge(attended.view(batch, -1, heads * head_dim))
        attended = self.norm1(attended)

        # feed-forward network on the concatenated [x, message]
        fused = torch.cat([x, attended], dim=2)
        update = self.norm2(self.mlp(fused))

        return x + update
class TopicFormer(nn.Module):
    """Topic-assisted transformer operating on a pair of feature sets.

    A set of learnable seed tokens (one per topic) is alternately updated
    from the concatenated features (``"seed"`` layers) and used to update the
    two feature sets (``"feat"`` layers).  Each feature is then assigned to a
    topic from its dot-product with the seeds; for a few sampled topics the
    features belonging to that topic are refined by additional self/cross
    attention layers.  Finally a dual-softmax confidence matrix between the
    two feature sets is produced.

    Config keys read: ``d_model``, ``nhead``, ``layer_names``, ``attention``,
    ``n_topic_transformers``, ``n_samples``, ``n_topics``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.d_model = config["d_model"]
        self.nhead = config["nhead"]
        self.layer_names = config["layer_names"]
        encoder_layer = LoFTREncoderLayer(
            config["d_model"], config["nhead"], config["attention"]
        )
        # one encoder layer per entry of layer_names ("seed" / "feat")
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in self.layer_names]
        )
        # per iteration: one self-attention and one cross-attention layer;
        # not built at all when in-topic refinement is disabled (n_samples == 0)
        self.topic_transformers = (
            nn.ModuleList(
                [
                    copy.deepcopy(encoder_layer)
                    for _ in range(2 * config["n_topic_transformers"])
                ]
            )
            if config["n_samples"] > 0
            else None
        )
        self.n_iter_topic_transformer = config["n_topic_transformers"]
        # assigning an nn.Parameter attribute already registers it on the module
        self.seed_tokens = nn.Parameter(
            torch.randn(config["n_topics"], config["d_model"])
        )
        self.n_samples = config["n_samples"]

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def sample_topic(self, prob_topics, topics, L):
        """Pick ``n_samples`` topics shared by both feature sets.

        Args:
            prob_topics (torch.Tensor): [N, L+S, K] soft topic probabilities
            topics (torch.Tensor): [N, L+S, K] hard topic assignments
            L (int): number of entries along dim 1 that belong to the first set
        Returns:
            None when ``self.n_samples == 0``; otherwise a tuple
            ``(sampled_topics0, sampled_topics1)`` holding the columns of
            ``topics`` for the chosen topics, shaped [N, L, n_samples] and
            [N, S, n_samples].
        """
        if self.n_samples == 0:
            return None

        prob_topics0, prob_topics1 = prob_topics[:, :L], prob_topics[:, L:]
        topics0, topics1 = topics[:, :L], topics[:, L:]

        # per-set topic distribution, then their normalised product: [N, K]
        theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1)
        theta1 = F.normalize(prob_topics1.sum(dim=1), p=1, dim=-1)
        theta = F.normalize(theta0 * theta1, p=1, dim=-1)

        if self.training:
            # stochastic choice during training
            sampled_inds = torch.multinomial(theta, self.n_samples)
        else:
            # deterministic choice of the most probable topics otherwise
            _, sampled_inds = torch.topk(theta, self.n_samples, dim=-1)

        sampled_topics0 = torch.gather(
            topics0,
            dim=-1,
            index=sampled_inds.unsqueeze(1).expand(-1, topics0.shape[1], -1),
        )
        sampled_topics1 = torch.gather(
            topics1,
            dim=-1,
            index=sampled_inds.unsqueeze(1).expand(-1, topics1.shape[1], -1),
        )
        return sampled_topics0, sampled_topics1

    def reduce_feat(self, feat, topick, N, C):
        """Pack the features selected by ``topick`` into a padded tensor.

        Args:
            feat (torch.Tensor): [N, T, C] features
            topick (torch.Tensor): [N, T] selection indicator (non-zero = keep)
            N (int): batch size
            C (int): channel width
        Returns:
            resized_feat (torch.Tensor): [N, max_len, C], selected rows packed
                to the front of each sample, zero padded; same dtype as ``feat``
            new_mask (torch.Tensor): [N, max_len] bool, True on packed rows
            selected_ids (torch.Tensor): [N, T] bool version of ``topick``
        """
        selected_ids = topick.bool()
        # count on the boolean mask so the lengths always agree with the rows
        # that are actually indexed (a float sum can be inexact in low precision)
        len_topic = selected_ids.sum(dim=-1)
        max_len = int(len_topic.max().item())

        # follow feat's dtype: the masked assignment below requires source and
        # destination dtypes to match (e.g. half / double precision inputs)
        resized_feat = torch.zeros(
            (N, max_len, C), dtype=feat.dtype, device=feat.device
        )
        # True for the first len_topic[i] positions of sample i
        new_mask = torch.arange(max_len, device=feat.device).unsqueeze(
            0
        ) < len_topic.unsqueeze(1)
        resized_feat[new_mask, :] = feat[selected_ids, :]
        return resized_feat, new_mask, selected_ids

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        Returns:
            feat0 (torch.Tensor): [N, L, C] updated features
            feat1 (torch.Tensor): [N, S, C] updated features
            conf_matrix (torch.Tensor): [N, L, S] dual-softmax confidence
            topic_matrix: in training mode an [N, L, S] tensor of topic
                co-occurrence probabilities; otherwise a dict with the hard
                topic assignments under ``"img0"`` and ``"img1"``
        """
        assert (
            self.d_model == feat0.shape[2]
        ), "the feature number of src and transformer must be equal"
        N, L, C = feat0.shape[0], feat0.shape[1], feat0.shape[2]

        seeds = self.seed_tokens.unsqueeze(0).repeat(N, 1, 1)

        # NOTE(review): `feat` is built from the *input* features and is not
        # refreshed after "feat" layers update feat0/feat1 — presumably
        # intentional; confirm before changing.
        feat = torch.cat((feat0, feat1), dim=1)
        if mask0 is not None:
            mask = torch.cat((mask0, mask1), dim=-1)
        else:
            mask = None

        for layer, name in zip(self.layers, self.layer_names):
            if name == "seed":
                # seeds attend to the concatenated features
                seeds = layer(seeds, feat, None, mask)
            elif name == "feat":
                # each feature set attends to the seeds
                feat0 = layer(feat0, seeds, mask0, None)
                feat1 = layer(feat1, seeds, mask1, None)

        # feature-to-seed similarity: soft probabilities and one-hot argmax
        dmatrix = torch.einsum("nmd,nkd->nmk", feat, seeds)
        prob_topics = F.softmax(dmatrix, dim=-1)

        feat_topics = torch.zeros_like(dmatrix).scatter_(
            -1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0
        )

        if mask is not None:
            feat_topics = feat_topics * mask.unsqueeze(-1)
            prob_topics = prob_topics * mask.unsqueeze(-1)

        # warn when at most three topics have more than 100 assigned features
        if (feat_topics.detach().sum(dim=1).sum(dim=0) > 100).sum() <= 3:
            logger.warning("topic distribution is highly sparse!")

        sampled_topics = self.sample_topic(prob_topics.detach(), feat_topics, L)
        if sampled_topics is not None:
            updated_feat0 = torch.zeros_like(feat0)
            updated_feat1 = torch.zeros_like(feat1)
            s_topics0, s_topics1 = sampled_topics
            for k in range(s_topics0.shape[-1]):
                topick0, topick1 = s_topics0[..., k], s_topics1[..., k]
                # refine only topics that are present in both feature sets
                if (topick0.sum() > 0) and (topick1.sum() > 0):
                    new_feat0, new_mask0, selected_ids0 = self.reduce_feat(
                        feat0, topick0, N, C
                    )
                    new_feat1, new_mask1, selected_ids1 = self.reduce_feat(
                        feat1, topick1, N, C
                    )
                    for idt in range(self.n_iter_topic_transformer):
                        self_layer = self.topic_transformers[idt * 2]
                        cross_layer = self.topic_transformers[idt * 2 + 1]
                        new_feat0 = self_layer(
                            new_feat0, new_feat0, new_mask0, new_mask0
                        )
                        new_feat1 = self_layer(
                            new_feat1, new_feat1, new_mask1, new_mask1
                        )
                        new_feat0 = cross_layer(
                            new_feat0, new_feat1, new_mask0, new_mask1
                        )
                        new_feat1 = cross_layer(
                            new_feat1, new_feat0, new_mask1, new_mask0
                        )
                    # scatter the refined rows back to their original positions
                    updated_feat0[selected_ids0, :] = new_feat0[new_mask0, :]
                    updated_feat1[selected_ids1, :] = new_feat1[new_mask1, :]

            # keep original features where no sampled topic applies
            feat0 = (1 - s_topics0.sum(dim=-1, keepdim=True)) * feat0 + updated_feat0
            feat1 = (1 - s_topics1.sum(dim=-1, keepdim=True)) * feat1 + updated_feat1

        conf_matrix = torch.einsum("nlc,nsc->nls", feat0, feat1) / C**0.5
        if self.training:
            topic_matrix = torch.einsum(
                "nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:]
            )
            # pairs whose hard topics differ are treated as outliers
            outlier_mask = torch.einsum(
                "nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:]
            )
        else:
            topic_matrix = {"img0": feat_topics[:, :L], "img1": feat_topics[:, L:]}
            outlier_mask = torch.ones_like(conf_matrix)
        if mask0 is not None:
            outlier_mask = outlier_mask * mask0[..., None] * mask1[:, None]
        conf_matrix.masked_fill_(~outlier_mask.bool(), -1e9)
        # dual softmax over both matching directions
        conf_matrix = F.softmax(conf_matrix, 1) * F.softmax(conf_matrix, 2)

        return feat0, feat1, conf_matrix, topic_matrix
class LocalFeatureTransformer(nn.Module):
    """Two-layer cross-attention transformer for a pair of feature sets.

    The first layer updates ``feat0`` from ``feat1``; the second layer then
    updates ``feat1`` from the already-updated ``feat0``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.d_model = config["d_model"]
        self.nhead = config["nhead"]
        self.layer_names = config["layer_names"]

        prototype = LoFTREncoderLayer(
            config["d_model"], config["nhead"], config["attention"]
        )
        # always exactly two layers, independent of len(layer_names)
        self.layers = nn.ModuleList(
            [copy.deepcopy(prototype), copy.deepcopy(prototype)]
        )

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        Returns:
            tuple: the updated ``(feat0, feat1)``
        """
        assert (
            feat0.shape[2] == self.d_model
        ), "the feature number of src and transformer must be equal"

        first_layer = self.layers[0]
        second_layer = self.layers[1]

        # feat0 attends to feat1, then feat1 attends to the refreshed feat0
        refined0 = first_layer(feat0, feat1, mask0, mask1)
        refined1 = second_layer(feat1, refined0, mask1, mask0)

        return refined0, refined1