import torch
from transformers import PreTrainedModel

from .extra_fns import ACT2FN
from .encoderblocks import EncoderBlocks
from .config import AbLangConfig


class AbEmbeddings(PreTrainedModel):
    """Amino-acid and position embeddings, followed by LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = config.ptid
        self.AAEmbeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.pad_token_id)
        # For position embeddings, padding_idx is always 0 because padded positions are assigned position id 0.
        self.PositionEmbeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.Dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, src):
        inputs_embeds = self.AAEmbeddings(src)
        position_ids = self.create_position_ids_from_input_ids(src, self.pad_token_id)
        position_embeddings = self.PositionEmbeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        return self.Dropout(self.LayerNorm(embeddings))

    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
        """
        Replace non-padding symbols with their position numbers.
        Padding tokens get position 0, which is ignored later on.
        """
        mask = input_ids.ne(padding_idx).int()
        return torch.cumsum(mask, dim=1).long() * mask


class AbLang(PreTrainedModel):
    """Embedding layer followed by a stack of transformer encoder blocks."""

    config_class = AbLangConfig

    def __init__(self, config):
        super().__init__(config)
        self.AbEmbeddings = AbEmbeddings(config)
        self.EncoderBlocks = EncoderBlocks(config)

    def forward(self, inputs):
        src = self.AbEmbeddings(inputs['input_ids'])
        # EncoderBlocks expects the inverted mask convention (1 marks positions to ignore), hence 1 - attention_mask.
        outputs = self.EncoderBlocks(src, attention_mask=1 - inputs['attention_mask'], output_attentions=False)
        return apply_cls_embeddings(inputs, outputs)


def apply_cls_embeddings(inputs, outputs):
    """Overwrite the [CLS] position with the mean of the residue states, ignoring padding and special tokens."""
    mask = inputs['attention_mask'].float()
    # torch.nonzero yields (row, col) pairs; later pairs overwrite earlier ones, so each
    # sequence index maps to the column of its last visible token, i.e. the [SEP] token.
    d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()}
    # Make the [SEP] token invisible to the pooling.
    for i in d:
        mask[i, d[i]] = 0
    mask[:, 0] = 0.0  # make the [CLS] token invisible as well
    mask = mask.unsqueeze(-1).expand(outputs.last_hidden_state.size())
    sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
    sum_mask = torch.clamp(mask.sum(1), min=1e-9)
    outputs.last_hidden_state[:, 0, :] = sum_embeddings / sum_mask
    return outputs
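

if __name__ == "__main__":
    # Minimal, self-contained sketch (not part of the model) illustrating the two tensor
    # tricks used above on toy data. The pad token id of 21 is an assumption made purely
    # for this example; the real value comes from config.ptid.
    pad_token_id = 21

    # Padding-aware position ids, mirroring create_position_ids_from_input_ids:
    # non-padding tokens get positions 1..n, padding tokens get position 0.
    input_ids = torch.tensor([[4, 7, 9, pad_token_id, pad_token_id]])
    not_pad = input_ids.ne(pad_token_id).int()
    position_ids = torch.cumsum(not_pad, dim=1).long() * not_pad
    print(position_ids)  # tensor([[1, 2, 3, 0, 0]])

    # Masked mean-pooling, mirroring apply_cls_embeddings (without the [SEP] handling):
    # average the visible residue states, then write the result into the first ([CLS]) position.
    hidden = torch.randn(1, 5, 8)                        # (batch, seq_len, hidden_size)
    attention_mask = torch.tensor([[1., 1., 1., 0., 0.]])
    pool_mask = attention_mask.clone()
    pool_mask[:, 0] = 0.0                                # exclude the [CLS] position itself from the average
    pool_mask = pool_mask.unsqueeze(-1).expand(hidden.size())
    pooled = torch.sum(hidden * pool_mask, 1) / torch.clamp(pool_mask.sum(1), min=1e-9)
    hidden[:, 0, :] = pooled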