# -*- coding: utf-8 -*-
# @Time : 2021/8/19 10:54 AM
# @Author : JianingWang
# @File : classification.py
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import RobertaModel
from transformers.activations import ACT2FN
from transformers.models.electra import ElectraModel
from transformers.models.roformer import RoFormerModel
from transformers.models.albert import AlbertModel
from transformers.models.bert import BertModel, BertPreTrainedModel
from transformers.models.deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.roberta import RobertaPreTrainedModel
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from transformers.models.megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
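# Maps config.model_type to the *PreTrainedModel base class that build_cls_model subclasses.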
PRETRAINED_MODEL_MAP = {
"bert": BertPreTrainedModel,
"deberta-v2": DebertaV2PreTrainedModel,
"roberta": RobertaPreTrainedModel,
"erlangshen": MegatronBertPreTrainedModel
}
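# Pools the hidden state of the first token (<s> / [CLS]) through a dense layer followed by a
# configurable activation (ACT2FN[hidden_act]); dropout is left commented out.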
class BertPooler(nn.Module):
def __init__(self, hidden_size, hidden_act, hidden_dropout_prob):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
# self.activation = nn.Tanh()
self.activation = ACT2FN[hidden_act]
# self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, features):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
# x = self.dropout(x)
x = self.dense(x)
x = self.activation(x)
return x
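# Builds a sequence-classification model class whose pretrained base class is picked from
# PRETRAINED_MODEL_MAP according to config.model_type.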
def build_cls_model(config):
BaseClass = PRETRAINED_MODEL_MAP[config.model_type]
class BertForClassification(BaseClass):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.model_type = config.model_type
self.problem_type = config.problem_type
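            # Instantiate the backbone under the attribute name its pretrained checkpoints expect
            # (e.g. self.bert, self.roberta), so weight loading maps cleanly.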
if self.model_type == "bert":
self.bert = BertModel(config)
elif self.model_type == "albert":
self.albert = AlbertModel(config)
# elif self.model_type == "chinesebert":
# self.bert = ChineseBertModel(config)
elif self.model_type == "roformer":
self.roformer = RoFormerModel(config)
elif self.model_type == "electra":
self.electra = ElectraModel(config)
elif self.model_type == "deberta-v2":
self.deberta = DebertaV2Model(config)
elif self.model_type == "roberta":
self.roberta = RobertaModel(config)
elif self.model_type == "erlangshen":
self.bert = MegatronBertModel(config)
self.pooler = BertPooler(config.hidden_size, config.hidden_act, config.hidden_dropout_prob)
if hasattr(config, "cls_dropout_rate"):
cls_dropout_rate = config.cls_dropout_rate
else:
cls_dropout_rate = config.hidden_dropout_prob
self.dropout = nn.Dropout(cls_dropout_rate)
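            # Classifier input width: hidden_size plus any additional hand-crafted feature dims;
            # for relation tasks it is hidden_size * 2 (two entity-marker representations).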
add_feature_dims = config.additional_feature_dims if hasattr(config, "additional_feature_dims") else 0
# self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
cls_hidden = config.hidden_size + add_feature_dims
if hasattr(config, "is_relation_task"):
cls_hidden = config.hidden_size * 2
self.classifier = nn.Linear(cls_hidden, config.num_labels)
self.init_weights()
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
pseudo_label=None,
pinyin_ids=None,
additional_features=None
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logits, outputs = None, None
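            # Gather backbone kwargs and drop the None entries (e.g. pinyin_ids is only used by ChineseBERT).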
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids,
"head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions,
"output_hidden_states": output_hidden_states, "return_dict": return_dict, "pinyin_ids": pinyin_ids}
inputs = {k: v for k, v in inputs.items() if v is not None}
if self.model_type == "chinesebert":
outputs = self.bert(**inputs)
elif self.model_type == "bert":
outputs = self.bert(**inputs)
elif self.model_type == "albert":
outputs = self.albert(**inputs)
elif self.model_type == "electra":
outputs = self.electra(**inputs)
elif self.model_type == "roformer":
outputs = self.roformer(**inputs)
elif self.model_type == "deberta-v2":
outputs = self.deberta(**inputs)
elif self.model_type == "roberta":
outputs = self.roberta(**inputs)
elif self.model_type == "erlangshen":
outputs = self.bert(**inputs)
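            # Relation task: locate the entity start-marker tokens (ids within config.start_token_ids)
            # and build the pooled representation from their hidden states instead of the usual pooler.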
if hasattr(self.config, "is_relation_task"):
w = torch.logical_and(input_ids >= min(self.config.start_token_ids), input_ids <= max(self.config.start_token_ids))
start_index = w.nonzero()[:, 1].view(-1, 2)
                # feed <start_entity> + <end_entity> hidden states to the classifier
pooler_output = torch.cat([torch.cat([x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y in zip(outputs.last_hidden_state, start_index)])
                # alternative: feed [CLS] + <start_entity> + <end_entity> to the classifier
# pooler_output = torch.cat([torch.cat([z, x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y, z in zip(outputs.last_hidden_state, start_index, outputs.last_hidden_state[:, 0])])
elif "pooler_output" in outputs:
pooler_output = outputs.pooler_output
else:
pooler_output = self.pooler(outputs[0])
pooler_output = self.dropout(pooler_output)
# pooler_output = self.LayerNorm(pooler_output)
if additional_features is not None:
pooler_output = torch.cat((pooler_output, additional_features), dim=1)
logits = self.classifier(pooler_output)
loss = None
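            # Loss choice follows config.problem_type: MSE for regression, BCEWithLogits for
            # multi-label classification, cross-entropy otherwise (with optional pseudo-label weighting).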
if labels is not None:
if self.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.float().view(-1, self.num_labels))
# elif self.problem_type in ["single_label_classification"] or hasattr(self.config, "is_relation_task"):
else:
# loss_fct = FocalLoss()
loss_fct = CrossEntropyLoss()
                    # Pseudo labels: split real vs. pseudo samples and weight their losses 0.9 / 0.1
if pseudo_label is not None:
train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
train_loss = loss_fct(train_logits.view(-1, self.num_labels), train_labels.view(-1)) if train_labels.nelement() else 0
pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels), pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
loss = 0.9 * train_loss + 0.1 * pseudo_loss
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return BertForClassification
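

# Illustrative usage sketch (hypothetical checkpoint name; kept as comments):
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained("bert-base-chinese", num_labels=2)  # config.model_type == "bert"
#   model_cls = build_cls_model(config)
#   model = model_cls.from_pretrained("bert-base-chinese", config=config)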