# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2021/9/17 20:02
"""
1. self.embed_scale
2.
"""
import math
import torch
from transformers.models.bert.modeling_bert import BertForMaskedLM, BertEmbeddings, BertModel, \
    BertEncoder, BertPooler, BertOnlyMLMHead, BertConfig, logger
from transformers import MODEL_FOR_MASKED_LM_MAPPING


class KplugEmbeddings(BertEmbeddings):
def __init__(self, config):
super().__init__(config)
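        # KPLUG change vs. BertEmbeddings: scale word embeddings by sqrt(hidden_size),
        # in the style of fairseq/BART's ``scale_embedding`` option.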
self.embed_scale = math.sqrt(config.hidden_size) # if config.scale_embedding else 1.0

    def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
        # If token_type_ids is not provided, fall back to the all-zeros buffer registered in the
        # constructor. The registered buffer makes it possible to trace the model without passing
        # token_type_ids (see transformers issue #5664).
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
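            # Apply the KPLUG embedding scale when building embeddings from input_ids.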
inputs_embeds = self.word_embeddings(input_ids) * self.embed_scale
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings


class KplugModel(BertModel):
def __init__(self, config, add_pooling_layer=True):
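        # Intentionally skip BertModel.__init__ (BertPreTrainedModel.__init__ runs instead)
        # so the default BertEmbeddings can be replaced with KplugEmbeddings below.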
super(BertModel, self).__init__(config)
self.config = config
self.embeddings = KplugEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.init_weights()


class KplugForMaskedLM(BertForMaskedLM):
def __init__(self, config):
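        # Intentionally skip BertForMaskedLM.__init__ so the backbone can be a KplugModel
        # (scaled embeddings) instead of a plain BertModel.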
super(BertForMaskedLM, self).__init__(config)
if config.is_decoder:
logger.warning(
"If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
self.bert = KplugModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
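

# Map BertConfig to KplugForMaskedLM in the masked-LM auto-model mapping
# (overrides the default BertForMaskedLM entry used by AutoModelForMaskedLM).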
MODEL_FOR_MASKED_LM_MAPPING[BertConfig] = KplugForMaskedLM
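

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): build a small,
    # randomly initialised KplugForMaskedLM and run one masked-LM forward pass.
    # The config sizes below are illustrative assumptions, not tied to any
    # released Kplug checkpoint.
    config = BertConfig(
        vocab_size=100,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
    )
    model = KplugForMaskedLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(outputs.loss.item(), tuple(outputs.logits.shape))  # loss, (2, 16, 100)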