"""Transformer encoder block with an optional text stream (multimodal mode) and
loading of pretrained ViT encoder-block weights from the official checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp

# Parameter names inside each ``Transformer/encoderblock_{n}`` entry of the
# pretrained ViT checkpoints, used by ``Block.load_from`` below.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
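

# NOTE: ``np2th`` is used by ``Block.load_from`` below but is neither imported
# nor defined in this section. The helper below is a minimal sketch matching
# the standard ViT weight-loading utility (numpy checkpoint array -> torch
# tensor); if the project already defines it elsewhere, import that instead.
def np2th(weights, conv=False):
    """Convert a numpy checkpoint array to a torch tensor (transpose conv kernels)."""
    if conv:
        weights = weights.transpose([3, 2, 1, 0])
    return torch.from_numpy(weights)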


class Block(nn.Module):
    """One Transformer encoder block.

    With ``mm=True`` the block keeps separate LayerNorm and MLP modules for a
    text stream, so image and text tokens can be processed side by side while
    sharing the same attention module.
    """

    def __init__(self, config, vis, mm=True):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        if mm:
            # Text-stream counterparts, created only in multimodal mode.
            self.att_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_text = Mlp(config)

        self.ffn = Mlp(config)
        self.attn = Attention(config, vis, mm)

    def forward(self, x, text=None):
        if text is None:
            # Image-only path: pre-norm attention with a residual connection,
            # followed by a pre-norm MLP with a residual connection.
            h = x
            x = self.attention_norm(x)
            x, text, weights = self.attn(x)
            x = x + h

            h = x
            x = self.ffn_norm(x)
            x = self.ffn(x)
            x = x + h
            return x
        else:
            # Multimodal path: the two streams share the attention module but
            # use their own LayerNorms and MLPs, each with residual connections.
            h = x
            h_text = text
            x = self.attention_norm(x)
            text = self.att_norm_text(text)

            x, text, weights_img = self.attn(x, text)

            x = x + h
            text = text + h_text

            h = x
            h_text = text
            x = self.ffn_norm(x)
            text = self.ffn_norm_text(text)
            x = self.ffn(x)
            text = self.ffn_text(text)
            x = x + h
            text = text + h_text

            # Return both streams so the updated text tokens reach the next block.
            return x, text

    def load_from(self, weights, n_block):
        """Copy the parameters of pretrained encoder block ``n_block`` into the
        visual stream of this block (the text-specific modules are left as
        initialized, since the ViT checkpoint has no text weights)."""
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            # Attention projections: reshape the checkpoint kernels to
            # (hidden, hidden) and transpose them for torch Linear layers.
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            # Feed-forward (MLP) weights.
            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            # LayerNorm parameters ("scale" in the checkpoint maps to torch's weight).
            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
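

if __name__ == "__main__":
    # Minimal smoke-test sketch. ``get_b16_config`` and the token counts below
    # are assumptions (a standard ViT-B/16-style config), not guaranteed by
    # this module; adapt them to the project's actual config helpers.
    cfg = configs.get_b16_config()
    block = Block(cfg, vis=False, mm=True)
    img_tokens = torch.randn(2, 197, cfg.hidden_size)
    txt_tokens = torch.randn(2, 40, cfg.hidden_size)
    img_tokens, txt_tokens = block(img_tokens, txt_tokens)
    print(img_tokens.shape, txt_tokens.shape)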