import torch
from transformers import PreTrainedModel
from .extra_fns import ACT2FN
from .encoderblocks import EncoderBlocks
from .config import AbLangConfig

class AbEmbeddings(PreTrainedModel):
    """Residue (token) and position embeddings, followed by LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = config.ptid
        self.AAEmbeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.pad_token_id)
        # Position 0 is reserved for padding, so padding_idx is always 0 here.
        self.PositionEmbeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.Dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, src):
        inputs_embeds = self.AAEmbeddings(src)
        position_ids = self.create_position_ids_from_input_ids(src, self.pad_token_id)   
        position_embeddings = self.PositionEmbeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        return self.Dropout(self.LayerNorm(embeddings))
        
    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
        """
        Replace non-padding symbols with their position numbers, counting from 1.
        Padding tokens get position 0, which is ignored later on because it is the
        padding_idx of PositionEmbeddings.
        """
        mask = input_ids.ne(padding_idx).int()
        return torch.cumsum(mask, dim=1).long() * mask
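    # Example (illustrative values only, not from the original module): with
    # padding_idx = 0 and input_ids = [[2, 5, 7, 9, 3, 0, 0]], the mask is
    # [[1, 1, 1, 1, 1, 0, 0]] and the returned position ids are
    # [[1, 2, 3, 4, 5, 0, 0]]: real tokens are numbered from 1 while padded
    # slots stay at position 0 and thus get a zero position embedding.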


class AbLang(PreTrainedModel):
    """AbLang encoder: residue/position embeddings followed by a stack of encoder blocks."""
    config_class = AbLangConfig

    def __init__(self, config):
        super().__init__(config)
        self.AbEmbeddings = AbEmbeddings(config)
        self.EncoderBlocks = EncoderBlocks(config)
        
    def forward(
            self, 
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            output_attentions=None,
            output_hidden_states=None,
        ):
        src = self.AbEmbeddings(input_ids)
        # EncoderBlocks uses the inverted mask convention: 1 marks positions to ignore.
        outputs = self.EncoderBlocks(src,
                                     attention_mask=1 - attention_mask,
                                     output_attentions=output_attentions,
                                     output_hidden_states=output_hidden_states)
        # Overwrite the [CLS] slot with a mean-pooled sequence embedding.
        return apply_cls_embeddings(attention_mask, outputs)
    
def apply_cls_embeddings(attention_mask, outputs):
    """
    Replace the hidden state at position 0 ([CLS]) with the mean of the residue
    embeddings, excluding [CLS], the final separator token and padding from the average.
    """
    mask = attention_mask.float()
    # torch.nonzero yields (row, col) pairs; later pairs overwrite earlier ones in the
    # dict, so each row maps to the column of its last non-padding token, i.e. the sep token.
    d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()}
    # Make the sep token invisible to the pooling.
    for i in d:
        mask[i, d[i]] = 0
    mask[:, 0] = 0.0  # make the cls token invisible as well
    mask = mask.unsqueeze(-1).expand(outputs.last_hidden_state.size())
    sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
    sum_mask = torch.clamp(mask.sum(1), min=1e-9)  # avoid division by zero
    outputs.last_hidden_state[:, 0, :] = sum_embeddings / sum_mask
    return outputs
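

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The checkpoint path,
# tokenizer and sequence below are placeholders; it assumes the model is
# published as a Hugging Face checkpoint with a matching tokenizer.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/ablang-checkpoint")
#   model = AbLang.from_pretrained("path/to/ablang-checkpoint")
#   model.eval()
#
#   tokens = tokenizer(["EVQLVESGGGLVQPGGSLRLSCAAS"], return_tensors="pt", padding=True)
#   with torch.no_grad():
#       out = model(input_ids=tokens["input_ids"],
#                   attention_mask=tokens["attention_mask"])
#
#   # apply_cls_embeddings has replaced position 0 with the mean residue
#   # embedding (cls, sep and padding excluded), so this slice is a
#   # per-sequence representation:
#   seq_embeddings = out.last_hidden_state[:, 0, :]
# ---------------------------------------------------------------------------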