# AttentionMaps / model.py
# Author: TexR6 — "initial commit" (d7b0f75), 12.3 kB
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from resnet import StdConv2d
from utils import (get_width_and_height_from_size, load_pretrained_weights,
get_model_params)
# Model names accepted by VisionTransformer.from_name / from_pretrained;
# any other name is rejected by _check_model_name_is_valid.
VALID_MODELS = (
    'ViT-B_16',
    'ViT-B_32',
    'ViT-L_16',
    'ViT-L_32',
    'R50+ViT-B_16',
)
class PositionEmbs(nn.Module):
    """Add a learned positional embedding to a token sequence.

    The embedding holds ``num_patches + 1`` positions: one per patch plus one
    extra slot for a prepended class token. It is initialised from a standard
    normal distribution and added by broadcasting over the batch dimension.

    Args:
        num_patches (int): Number of patch tokens.
        emb_dim (int): Token embedding size.
        dropout_rate (float): Dropout applied after the addition; ``0`` (or
            any non-positive value) disables it.
    """

    def __init__(self, num_patches, emb_dim, dropout_rate=0.1):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        # None (rather than an identity module) marks "no dropout".
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

    def forward(self, x):
        embedded = x + self.pos_embedding
        if self.dropout is None:
            return embedded
        return self.dropout(embedded)
class MlpBlock(nn.Module):
    """ Transformer Feed-Forward Block.

    Computes ``dropout2(fc2(dropout1(gelu(fc1(x)))))``, where both dropouts
    are applied only when ``dropout_rate > 0``.

    Args:
        in_dim (int): Size of the input's last dimension.
        mlp_dim (int): Hidden width of the block.
        out_dim (int): Size of the output's last dimension.
        dropout_rate (float): Dropout probability; ``0`` disables dropout.
    """

    def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.1):
        super(MlpBlock, self).__init__()
        # init layers
        self.fc1 = nn.Linear(in_dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, out_dim)
        self.act = nn.GELU()
        if dropout_rate > 0.0:
            self.dropout1 = nn.Dropout(dropout_rate)
            self.dropout2 = nn.Dropout(dropout_rate)
        else:
            # None marks "no dropout"; forward must check before calling.
            self.dropout1 = None
            self.dropout2 = None

    def forward(self, x):
        out = self.fc1(x)
        out = self.act(out)
        if self.dropout1 is not None:
            out = self.dropout1(out)
        out = self.fc2(out)
        # dropout2 is None when dropout_rate <= 0. Calling it unconditionally
        # raised "TypeError: 'NoneType' object is not callable".
        if self.dropout2 is not None:
            out = self.dropout2(out)
        return out
class LinearGeneral(nn.Module):
    """Affine map that contracts arbitrary axes with ``torch.tensordot``.

    ``weight`` has shape ``(*in_dim, *feat_dim)`` and is drawn from a standard
    normal distribution. ``bias`` has shape ``feat_dim`` and starts at zero.
    The caller picks which axes of the input and weight are contracted via the
    ``dims`` argument of ``forward``.

    Args:
        in_dim (tuple of int): Weight axes that are contracted with the input.
        feat_dim (tuple of int): Weight axes that remain in the output.
    """

    def __init__(self, in_dim=(768, ), feat_dim=(12, 64)):
        super().__init__()
        weight_shape = (*in_dim, *feat_dim)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.bias = nn.Parameter(torch.zeros(*feat_dim))

    def forward(self, x, dims):
        contracted = torch.tensordot(x, self.weight, dims=dims)
        return contracted + self.bias
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        in_dim (int): Token embedding size. It is split into ``heads`` heads of
            size ``in_dim // heads``. ``in_dim`` is expected to be divisible by
            ``heads``; otherwise the per-head projections cover only
            ``heads * (in_dim // heads)`` features.
        heads (int): Number of attention heads.
        dropout_rate (float): Dropout probability applied to the attention
            weights after the softmax; ``0`` (or less) disables it.
    """

    def __init__(self, in_dim, heads=8, dropout_rate=0.1):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.head_dim = in_dim // heads
        # Scores are divided by sqrt(head_dim).
        self.scale = self.head_dim**0.5
        self.query = LinearGeneral((in_dim, ), (self.heads, self.head_dim))
        self.key = LinearGeneral((in_dim, ), (self.heads, self.head_dim))
        self.value = LinearGeneral((in_dim, ), (self.heads, self.head_dim))
        self.out = LinearGeneral((self.heads, self.head_dim), (in_dim, ))
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x):
        # x is (batch, tokens, in_dim); the projections contract axis 2 and
        # yield (batch, tokens, heads, head_dim).
        q = self.query(x, dims=([2], [0]))
        k = self.key(x, dims=([2], [0]))
        v = self.value(x, dims=([2], [0]))
        # Move heads ahead of tokens: (batch, heads, tokens, head_dim).
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        # self.dropout used to be built but never applied, so the
        # attention-dropout rate was silently ignored.
        if self.dropout is not None:
            attn_weights = self.dropout(attn_weights)
        out = torch.matmul(attn_weights, v)
        # Back to (batch, tokens, heads, head_dim), then merge the heads into
        # (batch, tokens, in_dim).
        out = out.permute(0, 2, 1, 3)
        out = self.out(out, dims=([2, 3], [0, 1]))
        return out
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    Computes ``h = x + dropout(attn(norm1(x)))`` and then returns
    ``h + mlp(norm2(h))``. The post-attention dropout is skipped when
    ``dropout_rate <= 0``.

    Args:
        in_dim (int): Token embedding size (also the block's output size).
        mlp_dim (int): Hidden width of the feed-forward block.
        num_heads (int): Number of attention heads.
        dropout_rate (float): Dropout after attention and inside the MLP.
        attn_dropout_rate (float): Passed to SelfAttention as its dropout_rate.
    """

    def __init__(self,
                 in_dim,
                 mlp_dim,
                 num_heads,
                 dropout_rate=0.1,
                 attn_dropout_rate=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(in_dim)
        self.attn = SelfAttention(in_dim,
                                  heads=num_heads,
                                  dropout_rate=attn_dropout_rate)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.norm2 = nn.LayerNorm(in_dim)
        self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        attended = self.attn(self.norm1(x))
        if self.dropout:
            attended = self.dropout(attended)
        hidden = attended + x
        # Feed-forward sub-layer with residual connection.
        return self.mlp(self.norm2(hidden)) + hidden
class Encoder(nn.Module):
    """Transformer encoder: positional embedding, N encoder blocks, final LayerNorm.

    The input token sequence must broadcast against the positional embedding
    of ``num_patches + 1`` positions. In practice that means the patch tokens
    plus one prepended token.

    Args:
        num_patches (int): Number of patch tokens.
        emb_dim (int): Token embedding size.
        mlp_dim (int): Hidden width of each block's feed-forward layer.
        num_layers (int): Number of stacked EncoderBlocks.
        num_heads (int): Attention heads per block.
        dropout_rate (float): Dropout for embeddings, attention output and MLP.
        attn_dropout_rate (float): Dropout rate handed to each block's attention.
    """

    def __init__(self,
                 num_patches,
                 emb_dim,
                 mlp_dim,
                 num_layers=12,
                 num_heads=12,
                 dropout_rate=0.1,
                 attn_dropout_rate=0.0):
        super().__init__()
        self.pos_embedding = PositionEmbs(num_patches, emb_dim, dropout_rate)
        # Layers are built in order, so parameter initialisation order (and
        # state_dict keys) follow layer index.
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(emb_dim, mlp_dim, num_heads, dropout_rate,
                         attn_dropout_rate)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        hidden = self.pos_embedding(x)
        for block in self.encoder_layers:
            hidden = block(hidden)
        return self.norm(hidden)
class VisionTransformer(nn.Module):
    """ Vision Transformer.

    Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        params (namedtuple): A set of Params (as produced by
            ``get_model_params``). Required: ``None`` raises ``ValueError``.

    References:
        [1] https://arxiv.org/abs/2010.11929 (An Image is Worth 16x16 Words:
            Transformers for Image Recognition at Scale)

    Note:
        The positional embedding is sized from ``params.image_size`` and
        ``params.patch_size``. Inputs must therefore produce the same number
        of patch tokens as that configuration, or the embedding addition will
        fail to broadcast.

    Example:
        >>> import torch
        >>> from vision_transformer_pytorch import VisionTransformer
        >>> inputs = torch.rand(1, 3, 256, 256)
        >>> model = VisionTransformer.from_pretrained('ViT-B_16')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    def __init__(self, params=None):
        super(VisionTransformer, self).__init__()
        if params is None:
            # Every layer size below is read from params; fail with a clear
            # message instead of an AttributeError on ``None.resnet``.
            raise ValueError('VisionTransformer requires params; build the '
                             'model with VisionTransformer.from_name or '
                             'VisionTransformer.from_pretrained.')
        self._params = params
        if self._params.resnet:
            # Hybrid model: a ResNet backbone runs first and a 1x1 conv
            # projects its output to emb_dim. Assumes the backbone emits
            # ``width * 16`` channels — TODO confirm in resnet.py.
            self.resnet = self._params.resnet()
            self.embedding = nn.Conv2d(self.resnet.width * 16,
                                       self._params.emb_dim,
                                       kernel_size=1,
                                       stride=1)
        else:
            # Plain ViT: a conv with kernel == stride == patch_size embeds
            # non-overlapping patches.
            self.embedding = nn.Conv2d(3,
                                       self._params.emb_dim,
                                       kernel_size=self.patch_size,
                                       stride=self.patch_size)
        # class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self._params.emb_dim))
        # transformer
        self.transformer = Encoder(
            num_patches=self.num_patches,
            emb_dim=self._params.emb_dim,
            mlp_dim=self._params.mlp_dim,
            num_layers=self._params.num_layers,
            num_heads=self._params.num_heads,
            dropout_rate=self._params.dropout_rate,
            attn_dropout_rate=self._params.attn_dropout_rate)
        # classifier
        self.classifier = nn.Linear(self._params.emb_dim,
                                    self._params.num_classes)

    @property
    def image_size(self):
        """Pair form of ``params.image_size`` (unpacked as ``(h, w)`` below)."""
        return get_width_and_height_from_size(self._params.image_size)

    @property
    def patch_size(self):
        """Pair form of ``params.patch_size`` (unpacked as ``(fh, fw)`` below)."""
        return get_width_and_height_from_size(self._params.patch_size)

    @property
    def num_patches(self):
        """Number of patch tokens, excluding the class token.

        For the hybrid model the grid is additionally divided by
        ``self.resnet.downsample``.
        """
        h, w = self.image_size
        fh, fw = self.patch_size
        if hasattr(self, 'resnet'):
            gh, gw = h // fh // self.resnet.downsample, w // fw // self.resnet.downsample
        else:
            gh, gw = h // fh, w // fw
        return gh * gw

    def extract_features(self, x):
        """Return encoder outputs of shape (batch, 1 + num_patches, emb_dim).

        Position 0 holds the class token.
        """
        if hasattr(self, 'resnet'):
            x = self.resnet(x)
        emb = self.embedding(x)  # (n, c, gh, gw)
        emb = emb.permute(0, 2, 3, 1)  # (n, gh, gw, c)
        b, h, w, c = emb.shape
        emb = emb.reshape(b, h * w, c)
        # prepend class token; expand is a broadcast view (no copy) and
        # torch.cat materialises it anyway.
        cls_token = self.cls_token.expand(b, -1, -1)
        emb = torch.cat([cls_token, emb], dim=1)
        # transformer
        feat = self.transformer(emb)
        return feat

    def forward(self, x):
        """Classify ``x`` from the class-token (position 0) features."""
        feat = self.extract_features(x)
        # classifier
        logits = self.classifier(feat[:, 0])
        return logits

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create a vision transformer model according to name.

        Args:
            model_name (str): Name for vision transformer.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'image_size', 'patch_size',
                    'emb_dim', 'mlp_dim',
                    'num_heads', 'num_layers',
                    'num_classes', 'attn_dropout_rate',
                    'dropout_rate'

        Returns:
            A vision transformer model (randomly initialised).

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``.
        """
        cls._check_model_name_is_valid(model_name)
        params = get_model_params(model_name, override_params)
        model = cls(params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls,
                        model_name,
                        weights_path=None,
                        in_channels=3,
                        num_classes=1000,
                        **override_params):
        """Create a vision transformer model according to name and load weights.

        Args:
            model_name (str): Name for vision transformer.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            in_channels (int): Input data's channel number. Weights are loaded
                into a 3-channel model first; for any other value the first
                conv is then replaced by a freshly initialised layer.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for final linear layer. The
                pretrained classifier weights are loaded only when this is 1000.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key (pass ``num_classes`` as the explicit argument,
                not here, or the call raises ``TypeError``):
                    'image_size', 'patch_size',
                    'emb_dim', 'mlp_dim',
                    'num_heads', 'num_layers',
                    'attn_dropout_rate', 'dropout_rate'

        Returns:
            A pretrained vision transformer model.
        """
        model = cls.from_name(model_name,
                              num_classes=num_classes,
                              **override_params)
        load_pretrained_weights(model,
                                model_name,
                                weights_path=weights_path,
                                load_fc=(num_classes == 1000))
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validate a model name.

        Args:
            model_name (str): Name for vision transformer.

        Returns:
            None when the name is valid.

        Raises:
            ValueError: If ``model_name`` is not one of ``VALID_MODELS``.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' +
                             ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust the model's first convolution to ``in_channels`` if it is not 3.

        The replacement layer is newly constructed, so its weights are freshly
        initialised and any pretrained weights for that layer are discarded.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            if hasattr(self, 'resnet'):
                # Assumes resnet.root is indexable by 'conv' (e.g. a
                # ModuleDict-like container) — TODO confirm in resnet.py.
                self.resnet.root['conv'] = StdConv2d(in_channels,
                                                     self.resnet.width,
                                                     kernel_size=7,
                                                     stride=2,
                                                     bias=False,
                                                     padding=3)
            else:
                self.embedding = nn.Conv2d(in_channels,
                                           self._params.emb_dim,
                                           kernel_size=self.patch_size,
                                           stride=self.patch_size)