# from transformers.configuration_bert import BertConfig
# from transformers import BertPreTrainedModel
# from transformers.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertLayer, BaseModelOutput, BaseModelOutputWithPooling
# from transformers.modeling_bert import BERT_INPUTS_DOCSTRING, _TOKENIZER_FOR_DOC, _CONFIG_FOR_DOC
from transformers.models.bert.modeling_bert import BertConfig, BertPreTrainedModel, BertEmbeddings, \
    BertPooler, BertLayer, BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bert.modeling_bert import BERT_INPUTS_DOCSTRING, _TOKENIZER_FOR_DOC, _CONFIG_FOR_DOC
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
)


class WordEmbeddingAdapter(nn.Module):
    def __init__(self, config):
        super(WordEmbeddingAdapter, self).__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.tanh = nn.Tanh()
        # project the word vectors (word_embed_dim) into the character hidden space (hidden_size)
        self.linear1 = nn.Linear(config.word_embed_dim, config.hidden_size)
        self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
        # bilinear attention weight matrix used to score each character against its candidate words
        attn_W = torch.zeros(config.hidden_size, config.hidden_size)
        self.attn_W = nn.Parameter(attn_W)
        self.attn_W.data.normal_(mean=0.0, std=config.initializer_range)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, layer_output, word_embeddings, word_mask):
        """
        :param layer_output: output of a BERT layer, [b_size, len_input, d_model]
        :param word_embeddings: word vectors matched to each character, [b_size, len_input, num_word, d_word]
        :param word_mask: attention mask over the word vectors matched to each character, [b_size, len_input, num_word]
        """
        # transform
        # align the word-vector dimension with the character-vector dimension
        word_outputs = self.linear1(word_embeddings)
        word_outputs = self.tanh(word_outputs)
        word_outputs = self.linear2(word_outputs)
        word_outputs = self.dropout(word_outputs)  # word_outputs: [b_size, len_input, num_word, d_model]
        # if type(word_mask) == torch.long:
        word_mask = word_mask.bool()

        # For each character vector, compute attention weights over all of its matched word vectors
        # using a bilinear mapping, then take the weighted sum.
        # layer_output = layer_output.unsqueeze(2)  # layer_output: [b_size, len_input, 1, d_model]
        scores = torch.matmul(layer_output.unsqueeze(2), self.attn_W)  # [b_size, len_input, 1, d_model]
        scores = torch.matmul(scores, torch.transpose(word_outputs, 2, 3))  # [b_size, len_input, 1, num_word]
        scores = scores.squeeze(2)  # [b_size, len_input, num_word]
        scores.masked_fill_(word_mask, -1e9)  # set the attention scores of padded word slots to a very small number
        scores = F.softmax(scores, dim=-1)  # [b_size, len_input, num_word]
        attn = scores.unsqueeze(-1)  # [b_size, len_input, num_word, 1]

        # weighted sum: the representation of the word-vector set matched to each character
        weighted_word_embedding = torch.sum(word_outputs * attn, dim=2)  # [N, L, D]
        layer_output = layer_output + weighted_word_embedding

        layer_output = self.dropout(layer_output)
        layer_output = self.layer_norm(layer_output)
        return layer_output
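
# Illustrative usage sketch for WordEmbeddingAdapter (kept as a comment so nothing runs at
# import time). The sizes below are made up for illustration; the config is assumed to carry
# the extra `word_embed_dim` attribute that the adapter reads.
#
#   config = BertConfig(hidden_size=128, word_embed_dim=200)
#   adapter = WordEmbeddingAdapter(config)
#   layer_output = torch.randn(2, 8, 128)                 # [b_size, len_input, d_model]
#   word_embeddings = torch.randn(2, 8, 3, 200)           # [b_size, len_input, num_word, d_word]
#   word_mask = torch.zeros(2, 8, 3, dtype=torch.long)    # non-zero entries mark padded word slots
#   fused = adapter(layer_output, word_embeddings, word_mask)   # -> [2, 8, 128]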


class LEBertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case
    a layer of cross-attention is added between the self-attention layers, following the architecture
    described in `Attention is all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar,
    Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder`
    argument of the configuration set to :obj:`True`.
    To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        # Note: this is the BertEncoder defined below in this file (with the WordEmbeddingAdapter),
        # not the stock transformers BertEncoder.
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        word_embeddings=None,
        word_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            word_embeddings=word_embeddings,
            word_mask=word_mask,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        # adapter that fuses the external word (lexicon) embeddings into the character-level hidden states
        self.word_embedding_adapter = WordEmbeddingAdapter(config)

    def forward(
        self,
        hidden_states,
        word_embeddings,
        word_mask,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False):
                if use_cache:
                    # logger.warning(
                    #     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    # )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            # fuse the word (lexicon) information after layer i
            # if i == self.config.add_layer:
            if i >= int(self.config.add_layer):  # edit by wjn
                hidden_states = self.word_embedding_adapter(hidden_states, word_embeddings, word_mask)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # if not return_dict:
        #     return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_attentions,
                    # all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            past_key_values=next_decoder_cache,
        )
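

# ------------------------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original model code). It assumes a config that
# carries the two extra attributes used above: `word_embed_dim` (dimension of the external
# word vectors) and `add_layer` (first layer index after which word features are fused).
# All sizes below are illustrative only.
# ------------------------------------------------------------------------------------------
if __name__ == "__main__":
    config = BertConfig(
        vocab_size=1000,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=256,
        word_embed_dim=200,  # assumed extra attribute, read by WordEmbeddingAdapter
        add_layer=0,         # assumed extra attribute, read by BertEncoder
    )
    model = LEBertModel(config)
    model.eval()

    batch_size, seq_len, num_word = 2, 8, 3
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # one set of candidate word vectors per character; zeros in word_mask keep a slot,
    # non-zero entries are masked out in WordEmbeddingAdapter.forward
    word_embeddings = torch.randn(batch_size, seq_len, num_word, config.word_embed_dim)
    word_mask = torch.zeros(batch_size, seq_len, num_word, dtype=torch.long)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_embeddings=word_embeddings,
            word_mask=word_mask,
        )
    print(outputs.last_hidden_state.shape)  # expected: torch.Size([2, 8, 128])
    print(outputs.pooler_output.shape)      # expected: torch.Size([2, 128])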