# DaS / segmenter_model/decoder.py
# Author: vobecant; commit dd78229 ("Initial commit"); 6.85 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
from segmenter_model.blocks import Block, FeedForward
from segmenter_model.utils import init_weights
class DecoderLinear(nn.Module):
    """Linear segmentation decoder: one shared ``nn.Linear`` applied to every token.

    Every token coming from the encoder is mapped to ``n_cls`` logits, and the
    flat token sequence is then folded back into a 2-D grid of patches.
    """

    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()
        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls
        # a single linear classifier shared by all patch tokens
        self.head = nn.Linear(self.d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters exempt from weight decay (none here)."""
        return set()

    def forward(self, x, im_size):
        """Classify tokens and reshape them into a spatial grid.

        Args:
            x: token tensor; after the linear head it must match the einops
                pattern ``b (h w) c`` — presumably (batch, n_patches, d_encoder),
                TODO confirm with caller.
            im_size: ``(H, W)`` of the input image. Only ``H`` is used, to derive
                the number of patch rows; the column count is inferred.

        Returns:
            Tensor laid out as ``b c h w`` where ``c == n_cls``.
        """
        height, _ = im_size
        rows = height // self.patch_size
        logits = self.head(x)
        return rearrange(logits, "b (h w) c -> b c h w", h=rows)
class MaskTransformer(nn.Module):
    """Mask-transformer decoder.

    Learnable class embeddings are appended to the projected encoder tokens and
    the joint sequence is processed by a stack of transformer ``Block``s. The
    masks are scalar products between L2-normalised patch embeddings and
    L2-normalised class embeddings, followed by a ``LayerNorm`` over the class
    dimension.
    """

    def __init__(
        self,
        n_cls,
        patch_size,
        d_encoder,
        n_layers,
        n_heads,
        d_model,
        d_ff,
        drop_path_rate,
        dropout,
    ):
        # NOTE: statement order is kept as-is so that seeded random
        # initialisation (Block init, randn, init_weights, trunc_normal_)
        # consumes the RNG in exactly the same order as before.
        super().__init__()
        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.n_cls = n_cls
        self.d_model = d_model
        self.d_ff = d_ff
        self.scale = d_model ** -0.5

        # per-block drop-path rates, spaced linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
        self.proj_dec = nn.Linear(d_encoder, d_model)

        self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
        self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))

        self.decoder_norm = nn.LayerNorm(d_model)
        # normalises the last dimension of the mask logits, which has size n_cls
        self.mask_norm = nn.LayerNorm(n_cls)

        self.apply(init_weights)
        trunc_normal_(self.cls_emb, std=0.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters exempt from weight decay."""
        return {"cls_emb"}

    def _embed_with_classes(self, x):
        """Project encoder tokens to ``d_model`` and append the class embeddings.

        The class tokens are concatenated on dim 1 after the patch tokens, so the
        last ``n_cls`` positions of the returned sequence are the class tokens.
        """
        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        return torch.cat((x, cls_emb), 1)

    def forward(self, x, im_size, features_only=False, no_rearrange=False):
        """Decode encoder tokens into per-class masks (or patch features).

        Args:
            x: 3-D encoder token tensor with tokens on dim 1 (required by the
                concatenation with the class embeddings) — presumably
                (batch, n_patches, d_encoder), TODO confirm with caller.
            im_size: ``(H, W)`` of the input image; only ``H`` is used, to derive
                the number of patch rows.
            features_only: if True, return the projected patch features instead
                of masks (class-embedding projection is skipped).
            no_rearrange: if True, return the flat token layout ``b (h w) n``
                instead of reshaping to ``b n h w``.

        Returns:
            Masks of size ``n_cls`` on the channel/last dimension, or projected
            patch features when ``features_only`` is set.
        """
        H, W = im_size
        GS = H // self.patch_size

        x = self._embed_with_classes(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # patch tokens come first, the n_cls class tokens were appended last
        patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:]
        patches = patches @ self.proj_patch

        if features_only:
            if no_rearrange:
                return patches
            return rearrange(patches, "b (h w) n -> b n h w", h=int(GS))

        cls_seg_feat = cls_seg_feat @ self.proj_classes

        # L2-normalise along the embedding dimension. F.normalize clamps the norm
        # at a tiny eps, so a zero vector yields zeros instead of NaN; for any
        # non-degenerate vector the result equals x / ||x||.
        patches = F.normalize(patches, dim=-1)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=-1)

        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        if not no_rearrange:
            masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))
        return masks

    def get_attention_map(self, x, layer_id):
        """Run the decoder up to ``layer_id`` and return that block's attention output.

        Blocks ``0 .. layer_id - 1`` are applied normally; block ``layer_id`` is
        called with ``return_attention=True`` and its result is returned.

        Raises:
            ValueError: if ``layer_id`` is outside ``[0, n_layers)``.
        """
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(
                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
            )
        x = self._embed_with_classes(x)
        for blk in self.blocks[:layer_id]:
            x = blk(x)
        return self.blocks[layer_id](x, return_attention=True)
class DeepLabHead(nn.Sequential):
    """Segmentation head: ASPP -> 3x3 conv/BN/ReLU -> 1x1 classifier conv.

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: number of output channels produced by the final 1x1 conv.
        patch_size: ViT patch size. It is only needed when ``forward`` receives a
            3-D token tensor that must be folded into a 2-D feature map.
        atrous_rates: dilation rates for the ASPP branches (default matches the
            previously hard-coded ``[12, 24, 36]``).
    """

    def __init__(self, in_channels, num_classes, patch_size=None, atrous_rates=(12, 24, 36)):
        # 256 matches ASPP's default out_channels
        super().__init__(
            ASPP(in_channels, atrous_rates),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),
        )
        self.patch_size = patch_size

    def forward(self, x, im_size=None):
        """Apply the head, first reshaping 3-D token input into a feature map.

        Args:
            x: either a 4-D feature map or a 3-D token tensor laid out as
                ``b (h w) n`` (features from a ViT).
            im_size: ``(H, W)`` of the input image; required for 3-D input. Only
                ``H`` is used, to derive the number of patch rows.

        Raises:
            ValueError: if ``x`` is 3-D but ``im_size`` or ``self.patch_size`` is
                missing. This is an explicit check rather than an ``assert``, so
                it is not stripped under ``python -O``.
        """
        if x.dim() == 3:
            # features from ViT
            if im_size is None or self.patch_size is None:
                raise ValueError(
                    "DeepLabHead received 3-D token features; both im_size and "
                    "patch_size are required to reshape them into a feature map."
                )
            H, _ = im_size
            GS = H // self.patch_size
            x = rearrange(x, "b (h w) n -> b n h w", h=int(GS)).contiguous()
        # nn.Sequential.forward applies the sub-modules in registration order
        return super().forward(x)
class ASPPConv(nn.Sequential):
    """Dilated 3x3 conv -> BatchNorm -> ReLU branch used inside ASPP.

    ``padding`` is set equal to ``dilation`` so that, with a 3x3 kernel and
    stride 1, the spatial size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        super().__init__(conv, norm, activation)
class ASPPPooling(nn.Sequential):
    """Image-level pooling branch of ASPP.

    The input is globally averaged to 1x1, passed through a 1x1 conv/BN/ReLU,
    and bilinearly resized back to the input's spatial size.
    """

    def __init__(self, in_channels, out_channels):
        layers = (
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        super().__init__(*layers)

    def forward(self, x):
        target_size = x.shape[-2:]
        # nn.Sequential.forward runs the pooling/conv/BN/ReLU stack in order
        pooled = super().forward(x)
        return F.interpolate(pooled, size=target_size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling module.

    The same input is fed to several parallel branches:
    - a 1x1 conv/BN/ReLU branch,
    - one ``ASPPConv`` per entry of ``atrous_rates``,
    - an ``ASPPPooling`` branch.

    The branch outputs are concatenated on the channel dimension and fused by a
    1x1 conv/BN/ReLU followed by ``Dropout(0.5)``.

    Args:
        in_channels: channels of the input feature map.
        atrous_rates: iterable of dilation rates, one ``ASPPConv`` per rate.
        out_channels: channels produced by every branch and by the fusion conv.
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        # Branches are built in the original order (1x1, dilated convs, pooling)
        # so RNG-dependent initialisation is unchanged.
        modules = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ]
        for rate in tuple(atrous_rates):
            modules.append(ASPPConv(in_channels, out_channels, rate))
        modules.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(modules)

        # The fusion conv receives out_channels per branch. Derive the width from
        # the real branch count (len(atrous_rates) + 2) instead of hard-coding 5,
        # which only held for exactly three atrous rates.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Run all branches on ``x``, concatenate on dim 1, and fuse."""
        branch_outputs = [conv(x) for conv in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))