| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import BertModel |
| |
|
| | from .scalar_mix import ScalarMix |
| |
|
| |
|
class BertEmbedding(nn.Module):
    """Word-level embeddings built from the hidden states of a pretrained BERT.

    The last ``n_layers`` hidden states returned by BERT are combined with a
    ``ScalarMix``, the sub-word vectors belonging to each word are averaged,
    and the result is projected to ``n_out`` dimensions by a bias-free linear
    layer.

    Args:
        model: Name or path handed to ``BertModel.from_pretrained``.
        n_layers: Number of trailing BERT hidden states to mix.
        n_out: Size of the projected output embedding.
        requires_grad: Whether BERT's parameters receive gradients. When
            false, BERT is also switched to eval mode on every forward call.
    """

    def __init__(self, model, n_layers, n_out, requires_grad=False):
        super().__init__()

        # Hidden states of every layer are needed for the scalar mix below.
        self.bert = BertModel.from_pretrained(model, output_hidden_states=True)
        self.bert = self.bert.requires_grad_(requires_grad)

        self.n_layers = n_layers
        self.n_out = n_out
        self.requires_grad = requires_grad
        self.hidden_size = self.bert.config.hidden_size

        self.scalar_mix = ScalarMix(n_layers)
        self.projection = nn.Linear(self.hidden_size, n_out, False)

    def __repr__(self):
        """Return a compact summary; ``requires_grad`` is shown only if set."""
        s = self.__class__.__name__ + '('
        s += f"n_layers={self.n_layers}, n_out={self.n_out}"
        if self.requires_grad:
            s += f", requires_grad={self.requires_grad}"
        s += ')'

        return s

    def forward(self, subwords, bert_lens, bert_mask):
        """Compute one projected embedding per word position.

        Args:
            subwords: Sub-word ids fed to BERT.
                NOTE(review): presumably ``[batch_size, n_subwords]`` - confirm
                against the caller.
            bert_lens: 2-D tensor ``[batch_size, seq_len]`` giving the number
                of sub-words of each word; entries equal to 0 are treated as
                padding and yield the projection of a zero vector.
            bert_mask: Passed to BERT as ``attention_mask`` and also used to
                index the mixed hidden states, so it is assumed to be a
                boolean mask over ``subwords`` - TODO confirm dtype.

        Returns:
            Tensor of shape ``[batch_size, seq_len, n_out]``.
        """
        batch_size, seq_len = bert_lens.shape
        # Word positions that own at least one sub-word.
        mask = bert_lens.gt(0)

        if not self.requires_grad:
            # A frozen encoder should not apply dropout.
            self.bert.eval()

        # Run the encoder exactly once.
        outputs = self.bert(subwords, attention_mask=bert_mask)
        # Support both dict-like model outputs (attribute access) and the
        # legacy 3-tuple ``(last_hidden_state, pooler_output, hidden_states)``.
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is None:
            hidden_states = outputs[2]

        bert = self.scalar_mix(hidden_states[-self.n_layers:])

        # Keep real sub-words only, then cut them into one chunk per word.
        pieces = bert[bert_mask].split(bert_lens[mask].tolist())
        # Each word is represented by the mean of its sub-word vectors.
        bert = torch.stack([piece.mean(0) for piece in pieces])

        # Scatter the word vectors back to their positions; padding stays zero.
        bert_embed = bert.new_zeros(batch_size, seq_len, self.hidden_size)
        bert_embed = bert_embed.masked_scatter_(mask.unsqueeze(-1), bert)
        bert_embed = self.projection(bert_embed)

        return bert_embed
| |
|