import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_

from segmenter_model.blocks import Block, FeedForward
from segmenter_model.utils import init_weights


class DecoderLinear(nn.Module):
    """Linear decoder: a per-patch linear classifier on top of the encoder tokens."""

    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()

        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls

        self.head = nn.Linear(self.d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return set()

    def forward(self, x, im_size):
        H, W = im_size
        GS = H // self.patch_size
        x = self.head(x)
        # reshape the token sequence back into a 2D grid of per-patch logits
        x = rearrange(x, "b (h w) c -> b c h w", h=GS)

        return x


class MaskTransformer(nn.Module):
    """Mask transformer decoder: joint self-attention over patch tokens and learnable class tokens."""

    def __init__(
        self,
        n_cls,
        patch_size,
        d_encoder,
        n_layers,
        n_heads,
        d_model,
        d_ff,
        drop_path_rate,
        dropout,
    ):
        super().__init__()
        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.n_cls = n_cls
        self.d_model = d_model
        self.d_ff = d_ff
        self.scale = d_model ** -0.5

        # stochastic depth rates increase linearly across the decoder blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
        self.proj_dec = nn.Linear(d_encoder, d_model)

        self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
        self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))

        self.decoder_norm = nn.LayerNorm(d_model)
        self.mask_norm = nn.LayerNorm(n_cls)

        self.apply(init_weights)
        trunc_normal_(self.cls_emb, std=0.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_emb"}

    def forward(self, x, im_size, features_only=False, no_rearrange=False):
        H, W = im_size
        GS = H // self.patch_size

        # project from the encoder dimensionality to the decoder dimensionality (usually the same)
        x = self.proj_dec(x)
        # expand the learnable class embeddings to the batch size
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        # concatenate the class embedding tokens to the patch tokens
        x = torch.cat((x, cls_emb), 1)
        # forward the concatenated tokens through the decoder blocks
        for blk in self.blocks:
            x = blk(x)
        # perform normalization
        x = self.decoder_norm(x)

        # split into patch features and class-segmentation features
        patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:]
        # project the patch features
        patches = patches @ self.proj_patch

        if features_only:
            if not no_rearrange:
                features = rearrange(patches, "b (h w) n -> b n h w", h=int(GS))
            else:
                features = patches
            return features

        # project the class-segmentation features
        cls_seg_feat = cls_seg_feat @ self.proj_classes

        # scalar product between L2-normalized patch embeddings and class embeddings -> masks
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
        masks = patches @ cls_seg_feat.transpose(1, 2)

        masks = self.mask_norm(masks)
        if not no_rearrange:
            masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))

        return masks

    def get_attention_map(self, x, layer_id):
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(
                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
            )
        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        # run the blocks up to layer_id, then return that block's attention map
        for i, blk in enumerate(self.blocks):
            if i < layer_id:
                x = blk(x)
            else:
                return blk(x, return_attention=True)


class DeepLabHead(nn.Sequential):
    """DeepLabV3-style head (ASPP + classifier), extended to accept ViT token sequences."""

    def __init__(self, in_channels, num_classes, patch_size=None):
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1)
        )
        self.patch_size = patch_size

    def forward(self, x, im_size=None):
        if len(x.shape) == 3:
            # features from ViT: reshape the token sequence into a 2D feature map
            assert im_size is not None and self.patch_size is not None
            H, W = im_size
            GS = H // self.patch_size
            x = rearrange(x, "b (h w) n -> b n h w", h=int(GS)).contiguous()
        for module in self:
            x = module(x)
        return x


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        # upsample the globally pooled features back to the input resolution
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling: parallel 1x1, dilated 3x3, and image-pooling branches."""

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # fuse the five branches (1x1 + three dilated convs + pooling) into out_channels
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
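

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It runs a shape check on MaskTransformer and DeepLabHead with assumed
# hyper-parameters (d_model=192, patch_size=16, 19 classes); these values are
# chosen purely for the example and do not reflect any particular released
# configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = MaskTransformer(
        n_cls=19,
        patch_size=16,
        d_encoder=192,
        n_layers=2,
        n_heads=3,
        d_model=192,
        d_ff=4 * 192,
        drop_path_rate=0.0,
        dropout=0.1,
    ).eval()
    head = DeepLabHead(in_channels=192, num_classes=19, patch_size=16).eval()

    # One 224x224 image encoded with patch size 16 -> (224 // 16) ** 2 = 196 tokens.
    tokens = torch.randn(1, 196, 192)
    with torch.no_grad():
        masks = decoder(tokens, im_size=(224, 224))
        logits = head(tokens, im_size=(224, 224))
    print(masks.shape)   # torch.Size([1, 19, 14, 14])
    print(logits.shape)  # torch.Size([1, 19, 14, 14])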