# coding=utf-8
# author: xusong
# time: 2021/9/17 20:02

"""
KPLUG: a BERT-style masked language model whose word embeddings are scaled.

1. self.embed_scale: word embeddings are multiplied by sqrt(hidden_size)
2.
"""

import math

import torch

from transformers.models.bert.modeling_bert import (
    BertForMaskedLM,
    BertEmbeddings,
    BertModel,
    BertEncoder,
    BertPooler,
    BertOnlyMLMHead,
    BertConfig,
    logger,
)
from transformers import MODEL_FOR_MASKED_LM_MAPPING


class KplugEmbeddings(BertEmbeddings):
    """BertEmbeddings with word embeddings scaled by sqrt(hidden_size)."""

    def __init__(self, config):
        super().__init__(config)
        self.embed_scale = math.sqrt(config.hidden_size)  # if config.scale_embedding else 1.0

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]

        # When token_type_ids is not passed, fall back to the all-zeros buffer registered in the constructor.
        # The registered buffer helps users when tracing the model without passing token_type_ids (see issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            # The only change w.r.t. BertEmbeddings: scale the word embeddings by sqrt(hidden_size).
            inputs_embeds = self.word_embeddings(input_ids) * self.embed_scale

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class KplugModel(BertModel):
    """BertModel whose embedding layer is replaced by KplugEmbeddings."""

    def __init__(self, config, add_pooling_layer=True):
        # Call BertPreTrainedModel.__init__ directly (skipping BertModel.__init__)
        # so the standard embeddings are never built and then discarded.
        super(BertModel, self).__init__(config)
        self.config = config

        self.embeddings = KplugEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()


class KplugForMaskedLM(BertForMaskedLM):
    """BertForMaskedLM backed by KplugModel."""

    def __init__(self, config):
        # Skip BertForMaskedLM.__init__ so self.bert can be a KplugModel instead of a plain BertModel.
        super(BertForMaskedLM, self).__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.bert = KplugModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()


# Register KplugForMaskedLM under BertConfig so the auto classes resolve BERT configs to this model.
MODEL_FOR_MASKED_LM_MAPPING[BertConfig] = KplugForMaskedLM
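

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original K-PLUG code): build a tiny
    # KplugForMaskedLM from a fresh BertConfig and run one masked-LM forward pass.
    # The hyper-parameters and shapes below are illustrative assumptions, not the
    # real K-PLUG configuration.
    config = BertConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
    )
    model = KplugForMaskedLM(config)

    # Random token ids stand in for a tokenized, [MASK]-corrupted batch of 2 sequences of length 16.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    labels = input_ids.clone()

    outputs = model(input_ids=input_ids, labels=labels)
    print(outputs.loss.item(), outputs.logits.shape)  # MLM loss and (2, 16, vocab_size) logits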