# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

from ...ops.modules import MSDeformAttn
from .drop_path import DropPath


def get_reference_points(spatial_shapes, device):
    """Normalized (x, y) sampling centers for every position of every level,
    shaped (1, sum(H_ * W_), 1, 2)."""
    reference_points_list = []
    for H_, W_ in spatial_shapes:
        # indexing="ij" is the legacy torch.meshgrid default, made explicit to
        # silence the deprecation warning on newer PyTorch.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            indexing="ij",
        )
        ref_y = ref_y.reshape(-1)[None] / H_
        ref_x = ref_x.reshape(-1)[None] / W_
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    reference_points = reference_points[:, :, None]
    return reference_points


def deform_inputs(x, patch_size):
    """Build the (reference_points, spatial_shapes, level_start_index) triples
    for the two deformable-attention directions used by the adapter."""
    bs, c, h, w = x.shape
    # Injector direction: ViT patch tokens query the 3-level conv pyramid.
    spatial_shapes = torch.as_tensor(
        [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device
    )
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device)
    deform_inputs1 = [reference_points, spatial_shapes, level_start_index]

    # Extractor direction: conv-pyramid tokens query the single-level ViT token map.
    spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device)
    deform_inputs2 = [reference_points, spatial_shapes, level_start_index]

    return deform_inputs1, deform_inputs2


class ConvFFN(nn.Module):
    """Feed-forward block with a depth-wise 3x3 convolution between the two
    linear projections."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DWConv(nn.Module):
    """Depth-wise convolution applied per pyramid level of a packed token
    sequence (token counts in ratio 16:4:1; see the note below)."""

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        n = N // 21
        # Unflatten each level back to its 2D grid, convolve, then re-flatten.
        x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
        x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
        x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
        x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
        x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
        x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
        x = torch.cat([x1, x2, x3], dim=1)
        return x

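# A worked example of the hard-coded 21-way split in DWConv.forward (the
# numbers are illustrative, not from this file): the sequence concatenates
# three pyramid levels whose token counts stand in ratio 16:4:1, so
# N = 21 * n with n = (H // 2) * (W // 2). With H = W = 32:
#
#   x1: 16n = 4096 tokens -> (B, C, 64, 64)  # 1/8 resolution
#   x2:  4n = 1024 tokens -> (B, C, 32, 32)  # 1/16 resolution
#   x3:   n =  256 tokens -> (B, C, 16, 16)  # 1/32 resolution
#
# The same depth-wise 3x3 kernel is shared across the three levels.
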
class Extractor(nn.Module):
    """Deformable cross-attention from the conv features (query) to the ViT
    tokens (key/value), optionally followed by a ConvFFN."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        with_cffn=True,
        cffn_ratio=0.25,
        drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        with_cp=False,
    ):
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        self.with_cffn = with_cffn
        self.with_cp = with_cp
        if with_cffn:
            self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop)
            self.ffn_norm = norm_layer(dim)
            self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
        def _inner_forward(query, feat):
            attn = self.attn(
                self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
            )
            query = query + attn
            if self.with_cffn:
                query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))
            return query

        if self.with_cp and query.requires_grad:
            query = cp.checkpoint(_inner_forward, query, feat)
        else:
            query = _inner_forward(query, feat)
        return query


class Injector(nn.Module):
    """Deformable cross-attention from the ViT tokens (query) to the conv
    pyramid (key/value), added back through a learnable per-channel gate."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.0,
        with_cp=False,
    ):
        super().__init__()
        self.with_cp = with_cp
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        # LayerScale-style gate on the injected features, initialized to init_values.
        self.gamma = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
        def _inner_forward(query, feat):
            attn = self.attn(
                self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
            )
            return query + self.gamma * attn

        if self.with_cp and query.requires_grad:
            query = cp.checkpoint(_inner_forward, query, feat)
        else:
            query = _inner_forward(query, feat)
        return query


class InteractionBlock(nn.Module):
    """One adapter stage: inject conv features into the ViT tokens, run a slice
    of ViT blocks, then extract the updated tokens back into the conv features."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()
        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        for blk in blocks:
            x = blk(x, H_toks, W_toks)
        c = self.extractor(
            query=c,
            reference_points=deform_inputs2[0],
            feat=x,
            spatial_shapes=deform_inputs2[1],
            level_start_index=deform_inputs2[2],
            H=H_c,
            W=W_c,
        )
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(
                    query=c,
                    reference_points=deform_inputs2[0],
                    feat=x,
                    spatial_shapes=deform_inputs2[1],
                    level_start_index=deform_inputs2[2],
                    H=H_c,
                    W=W_c,
                )
        return x, c

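# Usage sketch (illustrative; the driving loop lives in the adapter backbone,
# not in this file): each stage receives the ViT tokens x, the flattened conv
# pyramid c (the concatenated 1/8, 1/16, 1/32 levels), and a slice of the ViT
# blocks, e.g.
#
#   deform_inputs1, deform_inputs2 = deform_inputs(img, patch_size)
#   x, c = stage(x, c, vit_blocks[start:end], deform_inputs1, deform_inputs2,
#                H_c, W_c, H_toks, W_toks)
#
# where (H_c, W_c) is the 1/16-resolution grid of the conv features and
# (H_toks, W_toks) is the ViT patch-token grid.
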
class InteractionBlockWithCls(nn.Module):
    """InteractionBlock variant for ViTs with a class token: the cls token is
    prepended before the ViT blocks run and split off again before extraction."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()
        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        # Reattach the cls token for the ViT blocks, then split it off again.
        x = torch.cat((cls, x), dim=1)
        for blk in blocks:
            x = blk(x, H_toks, W_toks)
        cls, x = x[:, :1], x[:, 1:]
        c = self.extractor(
            query=c,
            reference_points=deform_inputs2[0],
            feat=x,
            spatial_shapes=deform_inputs2[1],
            level_start_index=deform_inputs2[2],
            H=H_c,
            W=W_c,
        )
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(
                    query=c,
                    reference_points=deform_inputs2[0],
                    feat=x,
                    spatial_shapes=deform_inputs2[1],
                    level_start_index=deform_inputs2[2],
                    H=H_c,
                    W=W_c,
                )
        return x, c, cls

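# Shape note (derived from the strides below, not original comments): for an
# input of (B, 3, H, W), the stem yields c1 at 1/4 resolution (kept as a 2D
# map), and conv2/conv3/conv4 yield the 1/8, 1/16, and 1/32 levels that are
# flattened into token sequences. SyncBatchNorm is hard-coded, which assumes a
# torch.distributed training setup; for single-process experiments,
# nn.BatchNorm2d is a drop-in substitute (same constructor signature).
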
class SpatialPriorModule(nn.Module):
    """Convolutional stem that produces the multi-scale spatial prior from the
    input image and projects every level to the ViT embedding dimension."""

    def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        self.stem = nn.Sequential(
            *[
                nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(inplanes),
                nn.ReLU(inplace=True),
                nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                nn.SyncBatchNorm(inplanes),
                nn.ReLU(inplace=True),
                nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                nn.SyncBatchNorm(inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ]
        )
        self.conv2 = nn.Sequential(
            *[
                nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(2 * inplanes),
                nn.ReLU(inplace=True),
            ]
        )
        self.conv3 = nn.Sequential(
            *[
                nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(4 * inplanes),
                nn.ReLU(inplace=True),
            ]
        )
        self.conv4 = nn.Sequential(
            *[
                nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(4 * inplanes),
                nn.ReLU(inplace=True),
            ]
        )
        self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        def _inner_forward(x):
            c1 = self.stem(x)
            c2 = self.conv2(c1)
            c3 = self.conv3(c2)
            c4 = self.conv4(c3)
            c1 = self.fc1(c1)
            c2 = self.fc2(c2)
            c3 = self.fc3(c3)
            c4 = self.fc4(c4)

            bs, dim, _, _ = c1.shape
            # c1 = c1.view(bs, dim, -1).transpose(1, 2)  # 4s
            c2 = c2.view(bs, dim, -1).transpose(1, 2)  # 8s
            c3 = c3.view(bs, dim, -1).transpose(1, 2)  # 16s
            c4 = c4.view(bs, dim, -1).transpose(1, 2)  # 32s

            return c1, c2, c3, c4

        if self.with_cp and x.requires_grad:
            outs = cp.checkpoint(_inner_forward, x)
        else:
            outs = _inner_forward(x)

        return outs
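
if __name__ == "__main__":
    # Minimal smoke test for the pure-tensor helpers above (a sketch; running
    # the attention modules themselves requires the compiled MSDeformAttn op).
    # The input size and patch size below are illustrative.
    img = torch.randn(2, 3, 512, 512)
    deform_inputs1, deform_inputs2 = deform_inputs(img, patch_size=16)

    # Injector direction: one reference point per ViT patch token, keys/values
    # drawn from the three conv levels (1/8, 1/16, 1/32 resolution).
    ref1, shapes1, start1 = deform_inputs1
    assert ref1.shape == (1, 32 * 32, 1, 2)
    assert shapes1.tolist() == [[64, 64], [32, 32], [16, 16]]
    assert start1.tolist() == [0, 64 * 64, 64 * 64 + 32 * 32]

    # Extractor direction: one reference point per conv-pyramid token,
    # keys/values drawn from the single-level ViT token map.
    ref2, shapes2, start2 = deform_inputs2
    assert ref2.shape == (1, 64 * 64 + 32 * 32 + 16 * 16, 1, 2)
    assert shapes2.tolist() == [[32, 32]]
    assert start2.tolist() == [0]

    print("deform_inputs shape check passed")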