import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer


def weight_init_normal(module, model):
    """Re-initialise a module with the normal distribution defined by the backbone config."""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)



class MeanPooling(nn.Module):
    """Mask-aware mean pooling over the token dimension."""

    def __init__(self):
        super(MeanPooling, self).__init__()

    def forward(self, last_hidden_state, attention_mask):
        # Expand the attention mask over the hidden dimension so that
        # padding tokens contribute nothing to the sum
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)  # guard against division by zero
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings


class MeanPoolingLayer(nn.Module):
    """Mean pooling followed by a dropout -> linear -> sigmoid head."""

    def __init__(self,
        hidden_size,
        target_size,
        dropout=0,
    ):
        super(MeanPoolingLayer, self).__init__()
        self.pool = MeanPooling()
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, target_size),
            nn.Sigmoid()
        )

    def forward(self, inputs, mask):
        last_hidden_states = inputs[0]  # last_hidden_state of the backbone output
        feature = self.pool(last_hidden_states, mask)
        outputs = self.fc(feature)
        return outputs



class HSLanguageModel(nn.Module):
    """Transformer backbone with a mean-pooling head and sigmoid output."""

    def __init__(self,
        backbone='microsoft/deberta-v3-small',
        target_size=1,
        head_dropout=0,
        reinit_nlayers=0,
        freeze_nlayers=0,
        reinit_head=True,
        grad_checkpointing=False,
    ):
        super(HSLanguageModel, self).__init__()
        
        self.config = AutoConfig.from_pretrained(backbone, output_hidden_states=True)
        self.model = AutoModel.from_pretrained(backbone, config=self.config)
        self.head = MeanPoolingLayer(
            self.config.hidden_size,
            target_size,
            head_dropout
        )
        self.tokenizer = AutoTokenizer.from_pretrained(backbone)

        if grad_checkpointing:
            print('Gradient ckpt enabled')
            self.model.gradient_checkpointing_enable()
            
        if reinit_nlayers > 0:
            # Reinit last n encoder layers
            # [TODO] Check if it is autoencoding model: Bert, Roberta, DistilBert, Albert, XLMRoberta, BertModel
            for layer in self.model.encoder.layer[-reinit_nlayers:]: 
                self._init_weights(layer)
        
        if freeze_nlayers > 0:
            # Freeze the embeddings and the first n encoder layers
            self.model.embeddings.requires_grad_(False)
            self.model.encoder.layer[:freeze_nlayers].requires_grad_(False)
        
        if reinit_head:
            # Reinit layers in head
            self._init_weights(self.head)
        
        
    def _init_weights(self, layer):
        # Re-initialise every sub-module of the given layer
        for module in layer.modules():
            weight_init_normal(module, self)
    

    def forward(self, inputs):
        outputs = self.model(**inputs)
        outputs = self.head(outputs, inputs['attention_mask'])
        return outputs


if __name__ == '__main__':
    
    model = HSLanguageModel()
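    # Minimal usage sketch, assuming the backbone weights can be fetched from
    # the Hugging Face Hub: tokenize a toy sentence and check the output shape.
    sample = model.tokenizer(
        ['a short sentence to sanity-check the forward pass'],
        return_tensors='pt',
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        preds = model(sample)
    print(preds.shape)  # expected: torch.Size([1, 1]) with the default target_size=1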