File size: 4,148 Bytes
3bc4816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from torch import nn
from transformers import RobertaPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel

from utils import batched_index_select


class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with an additive-attention head that scores, for every
    token ``i``, each position ``j`` of the same sequence as a candidate head.

    ``forward`` returns the argmax head position per token together with the
    full (detached) probability table.
    """

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        hidden_size = config.hidden_size
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Additive attention: score(s, t) = v^T tanh(U s + W t).
        self.u_a = nn.Linear(hidden_size, hidden_size)
        self.w_a = nn.Linear(hidden_size, hidden_size)
        self.v_a_inv = nn.Linear(hidden_size, 1, bias=False)
        # Consumes the log-probabilities produced by the attention head.
        self.criterion = nn.NLLLoss()
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Predict a head position for every token.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded to ``RobertaModel``.
            labels: optional integer tensor indexed as ``labels[:, i]`` giving the
                gold head position of token ``i``; entries equal to -100 are
                ignored by the loss.
            **kwargs: accepted and ignored.

        Returns:
            ``(output, preds, parent_prob_table)``, plus ``labels`` as a fourth
            element when labels are given. ``output`` is a ``TokenClassifierOutput``
            whose ``loss`` is the sum of per-position NLL losses divided by
            ``seq_len`` (0.0 when no labels are given) and whose ``logits`` field
            holds ``preds``. ``preds[b, i]`` is the highest-probability head
            position for token ``i``; ``parent_prob_table[b, i, j]`` is the
            probability (no gradient) that token ``j`` is the head of token ``i``.
        """
        loss = 0.0
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        seq_len = output.size(1)

        # Neither projection depends on the loop index, so compute them once
        # rather than once per token.
        source_proj = self.u_a(output)
        target_proj = self.w_a(output)
        positions = torch.arange(seq_len, device=output.device)

        parent_prob_table = []
        for i in range(seq_len):
            # Exclude position i itself from token i's candidate heads. Masking
            # by index (not by comparing hidden values) cannot accidentally hit
            # another position whose activations happen to coincide.
            self_mask = (positions == i).view(1, seq_len, 1)
            p_head = self._log_attention(
                source_proj, target_proj[:, i : i + 1, :], self_mask
            )
            if labels is not None:
                gold = labels[:, i]
                # A mean-reduced NLLLoss over only ignored (-100) targets is NaN,
                # so fully-ignored columns contribute nothing.
                if not torch.all(gold == -100):
                    loss = loss + self.criterion(p_head.squeeze(-1), gold)
            parent_prob_table.append(p_head.detach().exp())

        # After the transpose, row i holds token i's distribution over heads.
        parent_prob_table = torch.cat(parent_prob_table, dim=2).transpose(1, 2)
        preds = parent_prob_table.topk(k=1, dim=2).indices.squeeze(-1)
        loss = loss / seq_len
        output = TokenClassifierOutput(loss=loss, logits=preds)

        if labels is not None:
            return output, preds, parent_prob_table, labels
        return output, preds, parent_prob_table

    def attention(self, source, target, mask=None):
        """Return log-softmax additive-attention scores over dim 1.

        Args:
            source: candidate states, passed through ``u_a``.
            target: query states, passed through ``w_a``; must broadcast against
                ``source`` after projection.
            mask: optional boolean tensor broadcastable to the ``(..., 1)`` score
                tensor; True entries are set to -1e4 before the softmax.
        """
        return self._log_attention(self.u_a(source), self.w_a(target), mask)

    def _log_attention(self, source_proj, target_proj, mask=None):
        """Score already-projected states: log_softmax(v(tanh(Us + Wt)), dim=1)."""
        scores = self.v_a_inv(torch.tanh(source_proj + target_proj))
        if mask is not None:
            # A large finite negative (representable in fp16) instead of -inf.
            scores = scores.masked_fill(mask, -1e4)
        return nn.functional.log_softmax(scores, dim=1)


class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with an MLP head that assigns each token one of
    ``num_labels`` classes from the concatenation of the token's own state and
    the state of its head token (given via ``head_labels``).
    """

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        hidden_size = config.hidden_size
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Fixed label inventory size; deliberately not read from config.num_labels.
        self.num_labels = 33
        self.hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        # Apply the model's configured weight initialisation to the new head layers.
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Classify every token given the position of its head token.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded to ``RobertaModel``.
            labels: optional integer tensor indexed as ``labels[:, i]``; entries
                equal to -100 are ignored by the loss.
            **kwargs: must contain ``head_labels``, an integer tensor indexed as
                ``head_labels[:, i]`` giving the head position of token ``i``.
                Entries equal to -100 are treated as position 0. The caller's
                tensor is not modified.

        Returns:
            ``TokenClassifierOutput`` with ``logits`` of shape
            ``(batch, seq_len, num_labels)`` and ``loss`` equal to the sum of
            per-position cross-entropy losses divided by ``seq_len`` (0.0 when
            no labels are given).

        Raises:
            KeyError: if ``head_labels`` is not passed.
        """
        loss = 0.0
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        seq_len = output.size(1)
        head_labels = kwargs["head_labels"]

        logits = []
        for i in range(seq_len):
            current_token = output[:, i, :]
            # Build a NEW index tensor. Assigning into head_labels[:, i] would
            # write through the view and corrupt the caller's head_labels.
            head_index = head_labels[:, i].masked_fill(head_labels[:, i] == -100, 0)
            # NOTE(review): the per-token clone of `output` may be unnecessary if
            # batched_index_select (utils) does not modify its input — confirm.
            head_embedding = batched_index_select(output.clone(), 1, head_index)
            combined_embeddings = torch.cat(
                (current_token, head_embedding.squeeze(1)), -1
            )
            pred = self.out(self.relu(self.hidden(combined_embeddings)))
            pred = pred.view(-1, self.num_labels)
            logits.append(pred)
            if labels is not None:
                gold = labels[:, i]
                # Mean-reduced cross-entropy over only ignored (-100) targets is
                # NaN, so fully-ignored columns contribute nothing.
                if not torch.all(gold == -100):
                    loss = loss + self.loss_fct(pred, gold.view(-1))

        loss = loss / seq_len
        logits = torch.stack(logits, dim=1)
        return TokenClassifierOutput(loss=loss, logits=logits)