import copy
import logging
import math
from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from .utils import get_b16_config
from .resnet_v2 import ResNetV2


CONFIGS = {
    'ViT-B_16': get_b16_config(),
    #'ViT-B_32': get_b32_config(),
    #'ViT-L_16': get_l16_config(),
    #'ViT-L_32': get_l32_config(),
    #'ViT-H_14': get_h14_config(),
    #'R50-ViT-B_16': get_r50_b16_config(),
    #'testing': configs.get_testing(),
}

# Key fragments of the pretrained checkpoint read by Block.load_from.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def _ckpt_key(*parts):
    """Join checkpoint key fragments with '/'.

    The checkpoint's keys are '/'-separated (e.g. "Transformer/encoder_norm/scale").
    os.path.join would insert '\\' on Windows and make every lookup fail.
    """
    return "/".join(parts)


def np2th(weights, conv=False):
    """Convert a numpy array to a torch tensor, optionally HWIO -> OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    """Swish activation: x * sigmoid(x)."""
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    forward(hidden_states) returns (attention_output, weights). weights holds the
    pre-dropout attention probabilities when ``vis`` is true, otherwise None.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        # The concatenated heads carry all_head_size features, so the output
        # projection must accept that many (equal to hidden_size when divisible).
        self.out = Linear(self.all_head_size, config.hidden_size)

        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, T, all_head) -> (B, heads, T, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge heads back: (B, heads, T, head_size) -> (B, T, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> GELU -> dropout -> fc2 -> dropout."""

    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero normal biases."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            # Hybrid: patches are taken from a ResNet feature map, so patch_size
            # is measured in feature-map cells, not image pixels.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        # (height, width), the same order Conv2d uses for kernel_size/stride.
        self.patch_size = patch_size
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # One extra position for the [CLS] token.
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def interpolate_pos_encoding(self, x, h, w):
        """Return position embeddings resized to the patch grid of an h x w input.

        Args:
            x: token sequence with the [CLS] token first; x.shape[1] - 1 patches
                and x.shape[-1] channels.
            h, w: spatial extent (in the units patch_size is measured in) covered
                by the patch grid.

        Returns:
            Tensor of shape (1, 1 + (h // patch_h) * (w // patch_w), dim).
            The stored embeddings are assumed to form a square sqrt(N) x sqrt(N) grid.
        """
        npatch = x.shape[1] - 1
        N = self.position_embeddings.shape[1] - 1
        # patch_size is (height, width): divide each axis by its own patch extent.
        h0 = h // self.patch_size[0]
        w0 = w // self.patch_size[1]
        if npatch == N and h0 == w0:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = x.shape[-1]
        side = int(math.sqrt(N))
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
            mode='bicubic',
            align_corners=False,
            recompute_scale_factor=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, x):
        """Embed an image batch into a [CLS] + patch token sequence."""
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        # Linear embedding -> (B, hidden, grid_h, grid_w)
        x = self.patch_embeddings(x)
        # Derive the grid from the conv output so the position encoding matches
        # the actual token count (also correct in hybrid mode, where the conv runs
        # on a feature map rather than on the raw image).
        grid_h, grid_w = x.shape[-2], x.shape[-1]

        # add the [CLS] token to the embed patch tokens
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        embeddings = x + self.interpolate_pos_encoding(
            x, grid_h * self.patch_size[0], grid_w * self.patch_size[1])
        embeddings = self.dropout(embeddings)
        return embeddings


class Block(nn.Module):
    """Transformer encoder block with LayerNorm applied before each sub-layer.

    forward(x) returns (x, attention_weights); weights are None unless vis=True.
    """

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        """Copy this block's parameters from a flat checkpoint dict.

        Args:
            weights: mapping from '/'-separated key to numpy array.
            n_block: block index; formatted into "Transformer/encoderblock_{n_block}".

        Checkpoint kernels are transposed into torch's (out, in) Linear layout.
        """
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[_ckpt_key(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[_ckpt_key(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[_ckpt_key(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[_ckpt_key(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[_ckpt_key(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[_ckpt_key(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[_ckpt_key(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[_ckpt_key(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[_ckpt_key(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[_ckpt_key(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[_ckpt_key(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[_ckpt_key(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[_ckpt_key(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[_ckpt_key(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[_ckpt_key(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[_ckpt_key(ROOT, MLP_NORM, "bias")]))


class Encoder(nn.Module):
    """Stack of transformer Blocks followed by a final LayerNorm.

    forward returns (encoded, attn_weights); attn_weights is a list of per-block
    weights when vis=True, otherwise an empty list.
    """

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            # Each Block is freshly constructed, so no deep copy is needed.
            self.layer.append(Block(config, vis))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    """Embeddings followed by the Encoder; returns (encoded, attn_weights)."""

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights


class VisionTransformer(nn.Module):
    """ViT backbone used as a feature extractor (no classification head).

    forward returns the final-layer [CLS] token features, or the patch-token
    features when use_patches=True.
    """

    def __init__(self, config, img_size=224, vis=False):
        super(VisionTransformer, self).__init__()
        self.embed_dim = config.hidden_size
        # Read by load_from when resizing position embeddings. Embeddings always
        # prepends a [CLS] token, so "token" is the default when config lacks it.
        self.classifier = getattr(config, "classifier", "token")
        self.transformer = Transformer(config, img_size, vis)

    def forward(self, x, labels=None, use_patches=False):
        """Encode images.

        Args:
            x: image batch passed to Embeddings.
            labels: ignored; kept for call-site compatibility.
            use_patches: if True return tokens [1:] (patches), else token 0 ([CLS]).
        """
        x, attn_weights = self.transformer(x)
        if use_patches:
            return x[:, 1:]
        return x[:, 0]

    def load_from(self, weights):
        """Load pretrained parameters from a flat '/'-keyed dict of numpy arrays.

        If the checkpoint's position embeddings have a different token count,
        the patch-position grid is resized with linear interpolation
        (scipy.ndimage.zoom, order=1).
        """
        with torch.no_grad():
            embeddings = self.transformer.embeddings
            embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                posemb_new.copy_(posemb)
            else:
                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                # Work in numpy explicitly for the scipy/numpy resize below.
                posemb = posemb.numpy()

                if self.classifier == "token":
                    # Keep the [CLS] position as-is; only the patch grid is resized.
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                posemb_new.copy_(np2th(posemb))

            # Blocks are stored as encoderblock_0, encoderblock_1, ... in the checkpoint.
            for block_idx, block in enumerate(self.transformer.encoder.layer):
                block.load_from(weights, n_block=str(block_idx))

            if embeddings.hybrid:
                embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)