import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer
import pickle

# Load the tokenizer saved alongside the distilled checkpoint
tokenizer = AutoTokenizer.from_pretrained("DistillMDPI1/DistillMDPI1/saved_tokenizer")

# Register <MULT> as an additional special token and cache the token ids used below
tokenizer.add_special_tokens({'additional_special_tokens': ['<MULT>']})
mult_token_id = tokenizer.convert_tokens_to_ids('<MULT>')
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
pad_token_id = tokenizer.pad_token_id

# Model hyperparameters
maxlen = 255                            # maximum sequence length (size of the positional embedding table)
batch_size = 32
max_pred = 5                            # maximum number of masked tokens predicted per sequence
n_layers = 6                            # number of encoder layers
n_heads = 12                            # attention heads per layer
d_model = 768                           # hidden size
d_ff = 768 * 4                          # feed-forward inner size
d_k = d_v = 64                          # per-head query/key and value dimensions
n_segments = 2
vocab_size = tokenizer.vocab_size + 1   # +1 for the added <MULT> token

def get_attn_pad_mask(seq_q, seq_k):
    """Build a (batch, len_q, len_k) boolean mask that is True on padding positions of seq_k."""
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # Mask out padding tokens so they receive no attention
    pad_attn_mask = seq_k.eq(pad_token_id).unsqueeze(1)       # (batch, 1, len_k)
    return pad_attn_mask.expand(batch_size, len_q, len_k)     # (batch, len_q, len_k)

class Embedding(nn.Module):
    """Token + positional + segment embeddings followed by LayerNorm."""
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)    # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)        # positional embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)    # segment (token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)                   # (batch, seq_len)
        embedding = self.tok_embed(x)
        embedding += self.pos_embed(pos)
        embedding += self.seg_embed(seg)
        return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        # (batch, n_heads, len_q, len_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)                  # mask out padding positions
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)                       # (batch, n_heads, len_q, d_v)
        return scores, context, attn

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)

        # Project and split into heads: (batch, n_heads, seq_len, d_k / d_v)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        # Broadcast the padding mask over all heads
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        scores, context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        # Merge heads back: (batch, seq_len, n_heads * d_v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.fc(context)
        return self.norm(output + residual), attn             # residual connection + LayerNorm

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.gelu = nn.GELU()

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, d_ff) -> (batch, seq_len, d_model)
        return self.fc2(self.gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask.to(enc_inputs.device))
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = nn.GELU()
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)       # binary classification head on [CLS]

        # Masked-LM decoder shares weights with the token embedding
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.mclassifier = nn.Linear(d_model, 17)     # 17-way classification head on the <MULT> tokens

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids).to(output.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        # Pooled [CLS] representation -> binary classifier
        h_pooled = self.activ1(self.fc(output[:, 0]))
        logits_clsf = self.classifier(h_pooled)

        # Gather hidden states at the masked positions for the masked-LM head
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))   # (batch, max_pred, d_model)
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(self.activ2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        # First <MULT> token is expected at position 1 (right after [CLS])
        h_mult_sent1 = self.activ1(self.fc(output[:, 1]))
        logits_mclsf1 = self.mclassifier(h_mult_sent1)

        # Locate the second <MULT> token of each sequence
        mult_positions = (input_ids == mult_token_id).nonzero(as_tuple=False)
        assert mult_positions.size(0) == 2 * input_ids.size(0), "each sequence must contain exactly two <MULT> tokens"
        mult2_positions = mult_positions[1::2][:, 1]  # keep the second occurrence per sequence (its position index)

        h_mult_sent2 = output[torch.arange(output.size(0)), mult2_positions]
        logits_mclsf2 = self.mclassifier(h_mult_sent2)

        return logits_lm, logits_clsf, logits_mclsf1, logits_mclsf2
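

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original training pipeline.
# It assumes each sequence is laid out as
#   [CLS] <MULT> tokens_a [SEP] <MULT> tokens_b [SEP] [PAD] ...
# i.e. exactly two <MULT> tokens per sequence with the first at position 1,
# which is what BERT.forward() asserts. The example texts, the sequence
# length of 16 and the all-zero masked_pos are arbitrary placeholders.
if __name__ == "__main__":
    model = BERT()
    seq_len = 16

    ids_a = tokenizer("hello world", add_special_tokens=False)["input_ids"][:4]
    ids_b = tokenizer("good morning", add_special_tokens=False)["input_ids"][:4]
    tokens = ([cls_token_id, mult_token_id] + ids_a + [sep_token_id]
              + [mult_token_id] + ids_b + [sep_token_id])
    first_sep = tokens.index(sep_token_id)
    tokens = tokens + [pad_token_id] * (seq_len - len(tokens))

    input_ids = torch.tensor([tokens, tokens])                 # batch of 2 identical sequences
    segment_ids = torch.zeros_like(input_ids)
    segment_ids[:, first_sep + 1:] = 1                         # second segment (and padding) -> 1
    masked_pos = torch.zeros(input_ids.size(0), max_pred, dtype=torch.long)

    logits_lm, logits_clsf, logits_mclsf1, logits_mclsf2 = model(input_ids, segment_ids, masked_pos)
    print(logits_lm.shape)      # (2, max_pred, vocab_size)
    print(logits_clsf.shape)    # (2, 2)
    print(logits_mclsf1.shape)  # (2, 17)
    print(logits_mclsf2.shape)  # (2, 17)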