import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
from .configuration_backpack_gpt2_nli import BackpackGPT2NLIConfig
from .modeling_backpack_gpt2 import BackpackGPT2Model


class BackpackGPT2NLIModel(GPT2PreTrainedModel):
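    """A Backpack GPT-2 encoder with a small NLI classification head on top.

    The hidden state of the last non-padding token (located via the attention mask)
    is passed through a two-layer MLP to score the NLI classes
    (entailment, neutral, contradiction).
    """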
    config_class = BackpackGPT2NLIConfig

    def __init__(self, config):
        super().__init__(config)

        self.backpack = BackpackGPT2Model(config)

        self.n_embd = config.n_embd

        self.num_labels = config.num_labels # 0: Entailment -- 1: Neutral -- 2: Contradiction
        
        self.nli_head = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Dropout(0.1),
            nn.Linear(self.n_embd, self.num_labels)
        )

        # Freeze the backpack encoder if requested by the config
        self.backpack.requires_grad_(not config.freeze_backpack)

        self.loss_func = nn.CrossEntropyLoss()

        # Model parallel
        self.model_parallel = False


    def forward(self, input_ids=None, attention_mask=None, labels=None):
        backpack_outputs = self.backpack(input_ids=input_ids, position_ids=None)
        # The backpack also returns a contextualization tensor, but only the hidden
        # states are used by the NLI head.
        backpack_hidden_states = backpack_outputs.hidden_states

        # Index of the last non-padding token of each sequence, according to the attention mask.
        last_toks_indices = attention_mask.shape[1] - 1 - attention_mask.flip((1,)).argmax(dim=1)
        last_backpack_hidden_states = backpack_hidden_states[
            torch.arange(backpack_hidden_states.shape[0], device=backpack_hidden_states.device),
            last_toks_indices,
            :,
        ]

        logits = self.nli_head(last_backpack_hidden_states)

        if labels is not None:
            # One label per sequence: compute cross-entropy over the sequence-level logits.
            loss = self.loss_func(logits.view(-1, self.num_labels), labels.view(-1))
            return {'logits': logits, 'loss': loss}
        else:
            return {'logits': logits}
        

    def predict(self, input_ids=None, attention_mask=None):
        logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=None)['logits']
        predictions = torch.argmax(logits, dim=1)
        labels = [self.config.id2label[index.item()] for index in predictions]
        return {'labels': labels, 'logits': logits}
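

# Example usage (an illustrative sketch, not part of the original model code). It assumes a
# checkpoint saved with this class -- the path below is a placeholder -- and a GPT-2-compatible
# tokenizer; the premise/hypothesis formatting is likewise an assumption, not documented
# preprocessing for this model.
if __name__ == "__main__":
    from transformers import GPT2TokenizerFast

    model = BackpackGPT2NLIModel.from_pretrained("path/to/backpack-gpt2-nli")  # placeholder path
    model.eval()

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    # Premise and hypothesis joined into a single sequence (assumed formatting).
    encoded = tokenizer(
        "A man is playing a guitar. The man is making music.",
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        out = model.predict(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
    print(out["labels"])  # e.g. ['entailment'], depending on the checkpoint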