"""Kplug masked language model built on top of the HuggingFace BERT implementation.

Differences from the stock BERT classes:

1. ``self.embed_scale``: word embeddings are multiplied by ``sqrt(hidden_size)``
   before the token type and position embeddings are added, as in the original
   Transformer.
2. ``KplugForMaskedLM`` is registered in ``MODEL_FOR_MASKED_LM_MAPPING`` for
   ``BertConfig``, replacing the default ``BertForMaskedLM``.
"""

import math

import torch
from transformers import MODEL_FOR_MASKED_LM_MAPPING
from transformers.models.bert.modeling_bert import (
    BertConfig,
    BertEmbeddings,
    BertEncoder,
    BertForMaskedLM,
    BertModel,
    BertOnlyMLMHead,
    BertPooler,
    logger,
)

class KplugEmbeddings(BertEmbeddings):
    """BertEmbeddings variant whose word embeddings are scaled by sqrt(hidden_size)."""

    def __init__(self, config):
        super().__init__(config)
        # Scale factor applied to the word embeddings, as in the original Transformer.
        self.embed_scale = math.sqrt(config.hidden_size)

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Fall back to the all-zero token_type_ids buffer registered by BertEmbeddings,
        # which lets the model be traced without explicitly passing token_type_ids.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            # Key difference from BertEmbeddings.forward: scale the word embeddings.
            inputs_embeds = self.word_embeddings(input_ids) * self.embed_scale
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class KplugModel(BertModel):
    """BertModel that swaps the standard BertEmbeddings for KplugEmbeddings."""

    def __init__(self, config, add_pooling_layer=True):
        # Call BertPreTrainedModel.__init__ directly (skipping BertModel.__init__)
        # so the embedding module can be replaced before the weights are initialised.
        super(BertModel, self).__init__(config)
        self.config = config

        self.embeddings = KplugEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

class KplugForMaskedLM(BertForMaskedLM):
    """BertForMaskedLM whose backbone is a KplugModel."""

    def __init__(self, config):
        # Bypass BertForMaskedLM.__init__ so the backbone can be replaced with KplugModel.
        super(BertForMaskedLM, self).__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `KplugForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.bert = KplugModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

# Register KplugForMaskedLM as the masked-LM model resolved for BertConfig,
# overriding the default BertForMaskedLM in the auto-model mapping.
MODEL_FOR_MASKED_LM_MAPPING[BertConfig] = KplugForMaskedLM
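
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, with arbitrary toy hyperparameters;
# not part of the original module). It runs a randomly initialised
# KplugForMaskedLM on dummy token ids; in practice you would load pretrained
# weights with KplugForMaskedLM.from_pretrained(...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = BertConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
    )
    model = KplugForMaskedLM(config)

    # Dummy batch of 2 sequences with 16 tokens each; the inputs double as MLM
    # labels purely to exercise the loss computation.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)

    print("mlm loss:", outputs.loss.item())
    print("logits shape:", tuple(outputs.logits.shape))  # (2, 16, vocab_size)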