# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from ...ops.modules import MSDeformAttn
from .drop_path import DropPath


def get_reference_points(spatial_shapes, device):
    """Build normalized (x, y) reference points, one per location of each feature level."""
    reference_points_list = []
    for H_, W_ in spatial_shapes:
        # Sample at pixel centers, then normalize to (0, 1).
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            indexing="ij",
        )
        ref_y = ref_y.reshape(-1)[None] / H_
        ref_x = ref_x.reshape(-1)[None] / W_
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    reference_points = reference_points[:, :, None]
    return reference_points
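
# Shape sketch (illustrative): for a single 14x14 level,
#   get_reference_points([(14, 14)], torch.device("cpu")).shape == (1, 196, 1, 2)
# with coordinates at pixel centers, normalized to (0, 1).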


def deform_inputs(x, patch_size):
    """Prepare the (reference_points, spatial_shapes, level_start_index) triples for the two
    deformable-attention directions: ViT tokens attending to the conv pyramid, and vice versa."""
    bs, c, h, w = x.shape
    # Injector direction: ViT tokens are the queries; the 1/8, 1/16, 1/32 conv levels are keys/values.
    spatial_shapes = torch.as_tensor(
        [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device
    )
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device)
    deform_inputs1 = [reference_points, spatial_shapes, level_start_index]
    # Extractor direction: the conv pyramid queries the single-level grid of ViT tokens.
    spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device)
    deform_inputs2 = [reference_points, spatial_shapes, level_start_index]
    return deform_inputs1, deform_inputs2
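
# Usage sketch (illustrative; assumes a 224x224 batch and 16-pixel patches):
#   x = torch.randn(2, 3, 224, 224)
#   di1, di2 = deform_inputs(x, patch_size=16)
#   di1[0].shape  # (1, 196, 1, 2): one reference point per ViT token
#   di2[0].shape  # (1, 1029, 1, 2): 28*28 + 14*14 + 7*7 pyramid points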


class ConvFFN(nn.Module):
    """MLP with a depth-wise 3x3 convolution between the two linear layers."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DWConv(nn.Module):
    """Depth-wise convolution applied separately to the three pyramid levels packed into one sequence."""

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        # The sequence concatenates three levels of sizes (2H x 2W), (H x W) and (H/2 x W/2),
        # i.e. 16n + 4n + n = 21n tokens with n = HW / 4.
        n = N // 21
        x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
        x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
        x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
        x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
        x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
        x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
        x = torch.cat([x1, x2, x3], dim=1)
        return x
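
# Worked shape check (illustrative, H = W = 14): n = 14 * 14 // 4 = 49, so the packed
# sequence holds 21 * 49 = 1029 tokens, split 784 / 196 / 49 over the 28x28, 14x14 and
# 7x7 levels, matching the conv pyramid built by deform_inputs above.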


class Extractor(nn.Module):
    """Pulls information from the ViT tokens back into the conv feature pyramid via deformable attention."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        with_cffn=True,
        cffn_ratio=0.25,
        drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        with_cp=False,
    ):
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        self.with_cffn = with_cffn
        self.with_cp = with_cp
        if with_cffn:
            self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop)
            self.ffn_norm = norm_layer(dim)
            self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
        def _inner_forward(query, feat):
            attn = self.attn(
                self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
            )
            query = query + attn
            if self.with_cffn:
                query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))
            return query

        if self.with_cp and query.requires_grad:
            # Trade compute for memory with activation checkpointing.
            query = cp.checkpoint(_inner_forward, query, feat)
        else:
            query = _inner_forward(query, feat)
        return query
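
# Call sketch (illustrative; di2 as returned by deform_inputs above): the conv pyramid c
# queries the ViT tokens x, with H and W set to the 1/16-scale base grid of c:
#   c = extractor(query=c, reference_points=di2[0], feat=x,
#                 spatial_shapes=di2[1], level_start_index=di2[2], H=h // 16, W=w // 16)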


class Injector(nn.Module):
    """Injects multi-scale conv features into the ViT tokens, scaled by a learnable per-channel gamma."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.0,
        with_cp=False,
    ):
        super().__init__()
        self.with_cp = with_cp
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
        )
        # init_values = 0 starts the injector as an identity mapping.
        self.gamma = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
        def _inner_forward(query, feat):
            attn = self.attn(
                self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
            )
            return query + self.gamma * attn

        if self.with_cp and query.requires_grad:
            query = cp.checkpoint(_inner_forward, query, feat)
        else:
            query = _inner_forward(query, feat)
        return query
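
# Call sketch (illustrative; mirror image of the Extractor): the ViT tokens x query the
# multi-scale conv features c, using di1 from deform_inputs above:
#   x = injector(query=x, reference_points=di1[0], feat=c,
#                spatial_shapes=di1[1], level_start_index=di1[2])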


class InteractionBlock(nn.Module):
    """One adapter stage: inject conv features into the ViT tokens, run a group of ViT blocks,
    then extract the updated tokens back into the conv features."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()
        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        for blk in blocks:
            x = blk(x, H_toks, W_toks)
        c = self.extractor(
            query=c,
            reference_points=deform_inputs2[0],
            feat=x,
            spatial_shapes=deform_inputs2[1],
            level_start_index=deform_inputs2[2],
            H=H_c,
            W=W_c,
        )
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(
                    query=c,
                    reference_points=deform_inputs2[0],
                    feat=x,
                    spatial_shapes=deform_inputs2[1],
                    level_start_index=deform_inputs2[2],
                    H=H_c,
                    W=W_c,
                )
        return x, c
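
# Flow sketch (illustrative): x are ViT patch tokens, c the packed conv pyramid from
# SpatialPriorModule, and blocks the slice of ViT blocks owned by this stage:
#   x, c = interaction_block(x, c, blocks, di1, di2,
#                            H_c=h // 16, W_c=w // 16,
#                            H_toks=h // patch_size, W_toks=w // patch_size)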


class InteractionBlockWithCls(nn.Module):
    """InteractionBlock variant that additionally carries a class token through the ViT blocks."""

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()
        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        # Prepend the cls token so it is updated alongside the patch tokens.
        x = torch.cat((cls, x), dim=1)
        for blk in blocks:
            x = blk(x, H_toks, W_toks)
        # Split the cls token back off before extraction, which only uses the spatial tokens.
        cls, x = x[:, :1], x[:, 1:]
        c = self.extractor(
            query=c,
            reference_points=deform_inputs2[0],
            feat=x,
            spatial_shapes=deform_inputs2[1],
            level_start_index=deform_inputs2[2],
            H=H_c,
            W=W_c,
        )
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(
                    query=c,
                    reference_points=deform_inputs2[0],
                    feat=x,
                    spatial_shapes=deform_inputs2[1],
                    level_start_index=deform_inputs2[2],
                    H=H_c,
                    W=W_c,
                )
        return x, c, cls


class SpatialPriorModule(nn.Module):
    """Convolutional stem producing a feature pyramid at 1/4, 1/8, 1/16 and 1/32 resolution,
    each level projected to the ViT embedding dimension."""

    def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.stem = nn.Sequential(
            nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.SyncBatchNorm(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.SyncBatchNorm(inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(2 * inplanes),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(4 * inplanes),
            nn.ReLU(inplace=True),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(4 * inplanes),
            nn.ReLU(inplace=True),
        )
        self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        def _inner_forward(x):
            c1 = self.stem(x)
            c2 = self.conv2(c1)
            c3 = self.conv3(c2)
            c4 = self.conv4(c3)
            c1 = self.fc1(c1)
            c2 = self.fc2(c2)
            c3 = self.fc3(c3)
            c4 = self.fc4(c4)
            bs, dim, _, _ = c1.shape
            # c1 stays a 2D map (1/4 scale); the coarser levels are flattened to token sequences.
            # c1 = c1.view(bs, dim, -1).transpose(1, 2)  # 4s
            c2 = c2.view(bs, dim, -1).transpose(1, 2)  # 8s
            c3 = c3.view(bs, dim, -1).transpose(1, 2)  # 16s
            c4 = c4.view(bs, dim, -1).transpose(1, 2)  # 32s
            return c1, c2, c3, c4

        if self.with_cp and x.requires_grad:
            outs = cp.checkpoint(_inner_forward, x)
        else:
            outs = _inner_forward(x)
        return outs
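

# Smoke-test sketch (illustrative; SyncBatchNorm needs an initialized distributed process
# group during training, so swap in nn.BatchNorm2d for a quick single-process check):
#   spm = SpatialPriorModule(inplanes=64, embed_dim=384)
#   c1, c2, c3, c4 = spm(torch.randn(2, 3, 224, 224))
#   c1.shape  # (2, 384, 56, 56); c2, c3, c4 flatten to (2, 784, 384), (2, 196, 384), (2, 49, 384)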