# Source: GECO2-demo, models/query_generator.py
# (uploaded by jerpelhan; commit 6146368, "Initial commit")
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
from torch import nn
from models.regression_head import UpsamplingLayer
from models.transformer import SelfCrossAttentionBlock, PrototypeAttentionBlock
from models.ops.modules.ms_deform_attn import MSDeformAttn
class C_base(nn.Module):
    """Prototype-conditioned refinement of three feature levels, fused coarse-to-fine.

    The levels are:

    * base   -- ``image_embeddings`` with shape ``(b, c, h, w)``;
    * ``l2`` -- ``hq_features[1]``, an intermediate-resolution map;
    * ``l1`` -- ``hq_features[0]``, the finest-resolution map.

    Each level goes through its own stack of ``PrototypeAttentionBlock``
    layers and then its own stack of single-level ``MSDeformAttn`` layers.
    ``up1``/``up2``/``up3`` then merge the levels from coarse to fine.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        num_prototype_attn_steps: int,
        num_image_attn_steps: int
    ) -> None:
        """
        Arguments:
          transformer_dim: channel width shared by every feature level; used
            as the embedding size of all attention blocks and upsampling
            layers (presumably must be divisible by the 8 attention heads).
          num_prototype_attn_steps: number of ``PrototypeAttentionBlock``
            layers stacked per level.
          num_image_attn_steps: number of ``MSDeformAttn`` layers stacked per
            level.
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        # One ModuleList per level: base, l1 (hq_features[0]), l2 (hq_features[1]).
        self.image_attention = nn.ModuleList()
        self.image_attention_l1 = nn.ModuleList()
        self.image_attention_l2 = nn.ModuleList()
        self.prototype_attention = nn.ModuleList()
        self.prototype_attention_l1 = nn.ModuleList()
        self.prototype_attention_l2 = nn.ModuleList()
        # Blocks are created in the same interleaved order (base, l1, l2 per
        # step) as before, so seeded random initialisation is unchanged.
        for _ in range(num_prototype_attn_steps):
            for blocks in (self.prototype_attention,
                           self.prototype_attention_l1,
                           self.prototype_attention_l2):
                blocks.append(
                    PrototypeAttentionBlock(
                        embedding_dim=transformer_dim,
                        num_heads=8,
                    )
                )
        for _ in range(num_image_attn_steps):
            for blocks in (self.image_attention,
                           self.image_attention_l1,
                           self.image_attention_l2):
                # d_model used to be hard-coded to 256, which only matched
                # transformer_dim == 256; use the configured width instead.
                blocks.append(
                    MSDeformAttn(d_model=transformer_dim, n_levels=1, n_heads=8, n_points=8)
                )
        self.up1 = UpsamplingLayer(transformer_dim, transformer_dim)
        self.up2 = UpsamplingLayer(transformer_dim, transformer_dim)
        self.up3 = UpsamplingLayer(transformer_dim, transformer_dim)
        self.up_aux = UpsamplingLayer(transformer_dim, transformer_dim)
        # Lazily filled by _geometry(): (h, w, device) -> (spatial_shapes,
        # level_start_index, reference_points). Plain attribute, so it is not
        # part of state_dict and checkpoints load unchanged.
        self._geometry_cache: dict = {}

    @staticmethod
    def init_weights(m: nn.Module) -> None:
        """Xavier-uniform init for ``nn.Linear`` weights; fill the bias with 0.01.

        Intended for ``module.apply(C_base.init_weights)``. The constructor
        does not apply it. Non-Linear modules and Linear layers without a
        bias are handled without error.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device='cpu'):
        """Return pixel-centre reference points normalised to ``[0, 1]``.

        Arguments:
          spatial_shapes: ``(L, 2)`` tensor of ``(H, W)`` per level.
          valid_ratios: ``(L, 2)`` tensor of ``(x_ratio, y_ratio)`` per level.
          device: device on which the coordinate grids are created.

        Returns:
          Tensor of shape ``(1, sum(H*W), L, 2)`` holding ``(x, y)`` pairs.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            H_, W_ = int(H_), int(W_)
            # 'ij' is the historical default; passing it explicitly silences
            # the torch.meshgrid deprecation warning.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[lvl, 0] * W_)
            reference_points_list.append(torch.stack((ref_x, ref_y), -1))
        reference_points = torch.cat(reference_points_list, 1)
        return reference_points[:, :, None] * valid_ratios[:, None]

    def _geometry(self, h: int, w: int, device: torch.device
                  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the cached single-level ``MSDeformAttn`` geometry for an ``h x w`` map.

        Returns ``(spatial_shapes, level_start_index, reference_points)`` on
        ``device``. The tensors are built on first use for each distinct
        ``(h, w, device)``. Earlier versions built them only after a device
        change, so a model that stayed on its initial device could miss them.
        """
        key = (int(h), int(w), device)
        cached = self._geometry_cache.get(key)
        if cached is None:
            spatial_shapes = torch.tensor([[int(h), int(w)]], device=device)
            level_start_index = torch.tensor([[0]], device=device)
            valid_ratios = torch.tensor([[1.0, 1.0]], device=device)
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=device)
            cached = (spatial_shapes, level_start_index, reference_points)
            self._geometry_cache[key] = cached
        return cached

    @staticmethod
    def _to_tokens(x: torch.Tensor) -> torch.Tensor:
        """Flatten a ``(b, c, h, w)`` map into ``(b, h*w, c)`` tokens."""
        return x.flatten(2).permute(0, 2, 1)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        prototype_embeddings: torch.Tensor,
        hq_features: torch.Tensor,
        hq_prototypes: torch.Tensor,
        hq_pos: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Refine the three feature levels and fuse them.

        Arguments:
          image_embeddings: base feature map, unpacked as ``(b, c, h, w)``.
          image_pe: positional encoding for ``image_embeddings``. It is
            repeated ``b`` times along dim 0, so it is presumably supplied
            with batch size 1 (TODO confirm against caller).
          prototype_embeddings: prototypes attended to by the base level.
          hq_features: indexable pair of maps with ``c`` channels; ``[0]`` is
            level ``l1`` (finest), ``[1]`` is level ``l2``.
          hq_prototypes: prototypes for ``l1`` (``[0]``) and ``l2`` (``[1]``).
          hq_pos: positional encodings for ``l1`` (``[0]``) and ``l2`` (``[1]``).

        Returns:
          ``(src, src_aux)``: ``up3`` applied to the fused levels, and
          ``up_aux`` applied to the refined ``l1`` level.
        """
        device = image_embeddings.device
        b, c, h, w = image_embeddings.shape
        # Resolutions come from the actual inputs instead of a hard-coded
        # 64x64 base with 2x/4x hq levels; those sizes behave exactly as before.
        h1, w1 = hq_features[0].shape[-2:]
        h2, w2 = hq_features[1].shape[-2:]
        shapes0, start0, ref0 = self._geometry(h, w, device)
        shapes1, start1, ref1 = self._geometry(h1, w1, device)
        shapes2, start2, ref2 = self._geometry(h2, w2, device)

        image_pe = self._to_tokens(torch.repeat_interleave(image_pe, b, dim=0))
        src = self._to_tokens(image_embeddings)
        hq_features_l1_pos = self._to_tokens(hq_pos[0])
        hq_features_l2_pos = self._to_tokens(hq_pos[1])
        hq_features_l1 = self._to_tokens(hq_features[0])
        hq_features_l2 = self._to_tokens(hq_features[1])

        for layer in self.prototype_attention:
            src = layer(image_f=src, prototypes=prototype_embeddings)
        for layer in self.prototype_attention_l1:
            hq_features_l1 = layer(image_f=hq_features_l1, prototypes=hq_prototypes[0])
        for layer in self.prototype_attention_l2:
            hq_features_l2 = layer(image_f=hq_features_l2, prototypes=hq_prototypes[1])

        # MSDeformAttn(query, reference_points, input_flatten, spatial_shapes,
        # level_start_index): the query carries the positional encoding and
        # the sampled values are the plain features.
        for layer in self.image_attention:
            src = layer(src + image_pe, ref0, src, shapes0, start0)
        for layer in self.image_attention_l1:
            hq_features_l1 = layer(hq_features_l1 + hq_features_l1_pos, ref1,
                                   hq_features_l1, shapes1, start1)
        for layer in self.image_attention_l2:
            hq_features_l2 = layer(hq_features_l2 + hq_features_l2_pos, ref2,
                                   hq_features_l2, shapes2, start2)

        src = src.transpose(1, 2).reshape(b, c, h, w)
        hq_features_l2 = hq_features_l2.transpose(1, 2).reshape(b, c, h2, w2)
        hq_features_l1 = hq_features_l1.transpose(1, 2).reshape(b, c, h1, w1)

        # The additions require up1/up2 outputs to match the l2/l1 resolutions.
        src = self.up1(src) + hq_features_l2
        src = self.up2(src) + hq_features_l1
        src = self.up3(src)
        src_aux = self.up_aux(hq_features_l1)
        return src, src_aux