# -*- coding: utf-8 -*-
# @Time    : 2021/8/19 10:54 AM
# @Author  : JianingWang
# @File    : classification.py
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import RobertaModel
from transformers.activations import ACT2FN
from transformers.models.electra import ElectraModel
from transformers.models.roformer import RoFormerModel
from transformers.models.albert import AlbertModel
from transformers.models.bert import BertModel, BertPreTrainedModel
from transformers.models.deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.roberta import RobertaPreTrainedModel
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from transformers.models.megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel

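# Map each supported config.model_type to the matching *PreTrainedModel base class;
# build_cls_model() uses this to pick the parent class of the generated model.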
PRETRAINED_MODEL_MAP = {
    "bert": BertPreTrainedModel,
    "deberta-v2": DebertaV2PreTrainedModel,
    "roberta": RobertaPreTrainedModel,
    "erlangshen": MegatronBertPreTrainedModel
}


class BertPooler(nn.Module):
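    """Pool the hidden state of the first token ([CLS] / <s>) through a dense layer and an activation."""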
    def __init__(self, hidden_size, hidden_act, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        # self.activation = nn.Tanh()
        self.activation = ACT2FN[hidden_act]
        # self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, features):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # x = self.dropout(x)
        x = self.dense(x)
        x = self.activation(x)
        return x


def build_cls_model(config):
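    """Build a sequence-classification model class whose pretrained base class
    matches ``config.model_type`` (see ``PRETRAINED_MODEL_MAP``)."""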
    BaseClass = PRETRAINED_MODEL_MAP[config.model_type]

    class BertForClassification(BaseClass):
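        """Sequence-classification model on top of the backbone selected by config.model_type."""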

        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels
            self.config = config
            self.model_type = config.model_type
            self.problem_type = config.problem_type

            if self.model_type == "bert":
                self.bert = BertModel(config)
            elif self.model_type == "albert":
                self.albert = AlbertModel(config)
            # elif self.model_type == "chinesebert":
            #     self.bert = ChineseBertModel(config)
            elif self.model_type == "roformer":
                self.roformer = RoFormerModel(config)
            elif self.model_type == "electra":
                self.electra = ElectraModel(config)
            elif self.model_type == "deberta-v2":
                self.deberta = DebertaV2Model(config)
            elif self.model_type == "roberta":
                self.roberta = RobertaModel(config)
            elif self.model_type == "erlangshen":
                self.bert = MegatronBertModel(config)
            self.pooler = BertPooler(config.hidden_size, config.hidden_act, config.hidden_dropout_prob)
            if hasattr(config, "cls_dropout_rate"):
                cls_dropout_rate = config.cls_dropout_rate
            else:
                cls_dropout_rate = config.hidden_dropout_prob
            self.dropout = nn.Dropout(cls_dropout_rate)
            add_feature_dims = config.additional_feature_dims if hasattr(config, "additional_feature_dims") else 0
            # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
            cls_hidden = config.hidden_size + add_feature_dims
            if hasattr(config, "is_relation_task"):
                cls_hidden = config.hidden_size * 2
            self.classifier = nn.Linear(cls_hidden, config.num_labels)

            self.init_weights()

        def forward(
                self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                pseudo_label=None,
                pinyin_ids=None,
                additional_features=None
        ):
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            logits, outputs = None, None
            inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids,
                      "head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions,
                      "output_hidden_states": output_hidden_states, "return_dict": return_dict, "pinyin_ids": pinyin_ids}
            inputs = {k: v for k, v in inputs.items() if v is not None}
            if self.model_type == "chinesebert":
                outputs = self.bert(**inputs)
            elif self.model_type == "bert":
                outputs = self.bert(**inputs)
            elif self.model_type == "albert":
                outputs = self.albert(**inputs)
            elif self.model_type == "electra":
                outputs = self.electra(**inputs)
            elif self.model_type == "roformer":
                outputs = self.roformer(**inputs)
            elif self.model_type == "deberta-v2":
                outputs = self.deberta(**inputs)
            elif self.model_type == "roberta":
                outputs = self.roberta(**inputs)
            elif self.model_type == "erlangshen":
                outputs = self.bert(**inputs)

            if hasattr(self.config, "is_relation_task"):
                # locate the two entity-marker tokens (config.start_token_ids) in each sequence
                w = torch.logical_and(input_ids >= min(self.config.start_token_ids), input_ids <= max(self.config.start_token_ids))
                start_index = w.nonzero()[:, 1].view(-1, 2)
                # classify on the concatenation of the <start_entity> and <end_entity> representations
                pooler_output = torch.cat([torch.cat([x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y in zip(outputs.last_hidden_state, start_index)])
                # alternative: classify on [CLS] + <start_entity> + <end_entity>
                # pooler_output = torch.cat([torch.cat([z, x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y, z in zip(outputs.last_hidden_state, start_index, outputs.last_hidden_state[:, 0])])

            elif "pooler_output" in outputs:
                pooler_output = outputs.pooler_output
            else:
                pooler_output = self.pooler(outputs[0])
            pooler_output = self.dropout(pooler_output)
            # pooler_output = self.LayerNorm(pooler_output)
            if additional_features is not None:
                pooler_output = torch.cat((pooler_output, additional_features), dim=1)
            logits = self.classifier(pooler_output)

            loss = None
            if labels is not None:
                if self.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.num_labels == 1:
                        loss = loss_fct(logits.squeeze(), labels.squeeze())
                    else:
                        loss = loss_fct(logits, labels)
                elif self.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.float().view(-1, self.num_labels))
                # elif self.problem_type in ["single_label_classification"] or hasattr(self.config, "is_relation_task"):
                else:
                    # loss_fct = FocalLoss()
                    loss_fct = CrossEntropyLoss()
                    # pseudo labels: weight the loss from real labels (0.9) and pseudo labels (0.1) separately
                    if pseudo_label is not None:
                        train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
                        train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
                        train_loss = loss_fct(train_logits.view(-1, self.num_labels), train_labels.view(-1)) if train_labels.nelement() else 0
                        pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels), pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
                        loss = 0.9 * train_loss + 0.1 * pseudo_loss
                    else:
                        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

    return BertForClassification
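

# --- Usage sketch (illustrative; not part of the original training pipeline) ---
# A minimal example, under the assumption that a tiny randomly-initialised BERT
# backbone is acceptable for demonstration: build the dynamic classification
# class via build_cls_model() and run a dummy batch. All hyper-parameters and
# tensor shapes below are placeholders.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=2,
    )
    config.problem_type = "single_label_classification"

    model_cls = build_cls_model(config)  # config.model_type == "bert" -> BertPreTrainedModel base
    model = model_cls(config)            # randomly initialised weights

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 1])

    output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(output.logits.shape)  # torch.Size([2, 2])
    print(output.loss.item())   # scalar cross-entropy loss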