import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from .utils import get_b16_config
from .resnet_v2 import ResNetV2
CONFIGS = {
    'ViT-B_16': get_b16_config(),
    #'ViT-B_32': get_b32_config(),
    #'ViT-L_16': get_l16_config(),
    #'ViT-L_32': get_l32_config(),
    #'ViT-H_14': get_h14_config(),
    #'R50-ViT-B_16': get_r50_b16_config(),
    #'testing': configs.get_testing(),
}
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
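# The names above are parameter-group prefixes used in the original JAX/Flax ViT
# checkpoints; Block.load_from joins them with pjoin to look up each tensor.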
def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
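# Multi-head self-attention: projects tokens of shape (B, N, hidden_size) into
# per-head queries/keys/values of shape (B, num_heads, N, head_size), applies
# scaled dot-product attention, and projects back to (B, N, hidden_size).
# When vis=True the softmaxed attention maps are also returned.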
class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)

        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
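# Position-wise feed-forward network of the transformer block: two Linear layers
# (hidden_size -> mlp_dim -> hidden_size) with GELU, dropout, and Xavier-initialised
# weights.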
class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
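# Patch + position embeddings: the image is split into patches with a strided Conv2d
# (or passed through a ResNetV2 stem first in the hybrid variant), a learnable [CLS]
# token is prepended, and learned position embeddings are added. The position
# embeddings are interpolated when the input resolution differs from the training one.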
class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_size = patch_size
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])
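    # interpolate_pos_encoding follows the DINO implementation referenced below: the
    # patch part of the learned position embedding is reshaped to its original
    # sqrt(N) x sqrt(N) grid, bicubically resized to the new grid, and re-flattened,
    # while the [CLS] position embedding is kept unchanged.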
    def interpolate_pos_encoding(self, x, h, w):
        npatch = x.shape[1] - 1
        N = self.position_embeddings.shape[1] - 1
        if npatch == N and w == h:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size[0]
        h0 = h // self.patch_size[1]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
            mode='bicubic',
            align_corners=False,
            recompute_scale_factor=False
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def forward(self, x):
        B, nc, h, w = x.shape
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.hybrid:
            x = self.hybrid_model(x)

        # Linear embedding
        x = self.patch_embeddings(x)

        # add the [CLS] token to the embed patch tokens
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        embeddings = x + self.interpolate_pos_encoding(x, h, w)
        embeddings = self.dropout(embeddings)
        return embeddings
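# A single pre-norm transformer encoder block:
#   x = x + Attention(LayerNorm(x))
#   x = x + Mlp(LayerNorm(x))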
class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights
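    # load_from copies one encoder block from the pretrained JAX checkpoint: Linear
    # kernels are stored as (in, out) and are therefore transposed with .t() before
    # being copied into the corresponding PyTorch (out, in) weight matrices.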
    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
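# Stack of num_layers transformer Blocks followed by a final LayerNorm; per-layer
# attention maps are collected only when vis=True.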
class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights
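# Thin wrapper tying the patch/position Embeddings to the Encoder stack.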
class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights
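# Backbone-only Vision Transformer: the classification head of the original model is
# commented out, and forward returns either the [CLS] embedding (default) or the
# patch-token embeddings (use_patches=True).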
class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, vis=False):
        super(VisionTransformer, self).__init__()
        #self.num_classes = num_classes
        self.classifier = config.classifier  # needed by load_from when resizing position embeddings
        self.embed_dim = config.hidden_size
        self.transformer = Transformer(config, img_size, vis)
        #self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None, use_patches=False):
        x, attn_weights = self.transformer(x)
        #logits = self.head(x[:, 0])

        #if labels is not None:
        #    loss_fct = CrossEntropyLoss()
        #    loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
        #    return loss
        #else:
        #    return logits, attn_weights
        if use_patches:
            return x[:, 1:]
        else:
            return x[:, 0]
    def load_from(self, weights):
        with torch.no_grad():
            #if self.zero_head:
            #    nn.init.zeros_(self.head.weight)
            #    nn.init.zeros_(self.head.bias)
            #else:
            #    self.head.weight.copy_(np2th(weights["head/kernel"]).t())
            #    self.head.bias.copy_(np2th(weights["head/bias"]).t())

            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)

                if self.classifier == "token":
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)
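

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): it assumes the
    # relative imports above resolve, e.g. when the file is run as a module inside its
    # package, and that get_b16_config() describes the standard ViT-B/16 (16x16 patches,
    # hidden_size 768). The batch size and the 224x224 input below are illustrative.
    config = CONFIGS['ViT-B_16']
    model = VisionTransformer(config, img_size=224, vis=False)
    dummy = torch.randn(2, 3, 224, 224)
    cls_embedding = model(dummy)                   # (2, hidden_size): [CLS] token feature
    patch_tokens = model(dummy, use_patches=True)  # (2, n_patches, hidden_size): patch tokens
    print(cls_embedding.shape, patch_tokens.shape)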