from abc import ABCMeta
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from src.configuration import BertABSAConfig
from transformers import BertModel, BertForSequenceClassification, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class BertBaseForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
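    """Thin wrapper that delegates the whole forward pass (including loss
    computation) to a pretrained BertForSequenceClassification head."""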
    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertBaseForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.dropout = nn.Dropout(config.dropout_rate)

        self.bert = BertForSequenceClassification.from_pretrained('bert-base-uncased',  # noqa
                                                                  output_hidden_states=False,  # noqa
                                                                  output_attentions=False,  # noqa
                                                                  num_labels=self.num_classes)  # noqa
        print("BERT Model Loaded")

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                        token_type_ids=token_type_ids, labels=labels)
        return out


class BertLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
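    """Stacks the per-layer [CLS] representations of BERT, runs them through
    an LSTM, and classifies from the final LSTM state."""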
    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertLSTMForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm
        self.dropout = nn.Dropout(config.dropout_rate)

        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")
        self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim_lstm, batch_first=True)  # noqa
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_states = bert_output["hidden_states"]

        # Take the [CLS] vector from each of the first `num_layers` hidden
        # states and stack them as a sequence of shape (batch, num_layers, embed_dim).
        hidden_states = torch.stack([hidden_states[layer_i][:, 0, :]
                                     for layer_i in range(self.num_layers)], dim=1)
        out, _ = self.lstm(hidden_states)
        out = self.dropout(out[:, -1, :])
        logits = self.fc(out)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        out = SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )
        return out


class BertAttentionForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
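    """Pools the per-layer [CLS] representations of BERT with a learned
    attention query, then classifies the pooled vector."""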
    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertAttentionForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.fc_hidden = config.fc_hidden
        self.dropout = nn.Dropout(config.dropout_rate)

        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")

        # Learnable attention query and output projection. The tensors are
        # converted to float32 before wrapping in nn.Parameter so they are
        # registered with the module and moved/trained together with it.
        q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.embed_dim))
        self.q = nn.Parameter(torch.from_numpy(q_t).float())
        w_ht = np.random.normal(loc=0.0, scale=0.1, size=(self.embed_dim, self.fc_hidden))  # noqa
        self.w_h = nn.Parameter(torch.from_numpy(w_ht).float())

        self.fc = nn.Linear(self.fc_hidden, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_states = bert_output["hidden_states"]

        # Take the [CLS] vector from each of the first `num_layers` hidden
        # states and stack them into shape (batch, num_layers, embed_dim).
        hidden_states = torch.stack([hidden_states[layer_i][:, 0, :]
                                     for layer_i in range(self.num_layers)], dim=1)
        out = self.attention(hidden_states)
        out = self.dropout(out)
        logits = self.fc(out)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        out = SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )
        return out

    def attention(self, h):
        # h: (batch, num_layers, embed_dim)
        # Score each layer's [CLS] vector against the learned query q.
        v = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)       # (batch, num_layers)
        v = F.softmax(v, -1)
        # Attention-weighted sum over layers.
        v_temp = torch.matmul(v.unsqueeze(1), h).transpose(-2, -1)     # (batch, embed_dim, 1)
        # Project down to the classifier's hidden size.
        v = torch.matmul(self.w_h.transpose(1, 0), v_temp).squeeze(2)  # (batch, fc_hidden)
        return v
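

# Minimal usage sketch (illustrative only): builds the LSTM variant and runs a
# single forward pass on a sentence/aspect pair. It assumes BertABSAConfig
# accepts num_classes, embed_dim, num_layers, hidden_dim_lstm, fc_hidden and
# dropout_rate as keyword arguments; adjust the values and the label mapping
# to the actual configuration used in this project.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    config = BertABSAConfig(num_classes=3,        # e.g. negative / neutral / positive
                            embed_dim=768,        # BERT-base hidden size
                            num_layers=12,        # hidden states stacked per example
                            hidden_dim_lstm=256,
                            fc_hidden=128,
                            dropout_rate=0.1)
    model = BertLSTMForSequenceClassification(config)
    model.eval()

    # Sentence and aspect are encoded as a BERT text pair; token_type_ids
    # distinguish the two segments.
    enc = tokenizer("The battery life is great but the screen is dim.",
                    "battery life",
                    return_tensors="pt")
    with torch.no_grad():
        out = model(input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    token_type_ids=enc["token_type_ids"],
                    labels=torch.tensor([2]))
    print(out.logits.shape, out.loss)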