# -*- coding: utf-8 -*-
# @Time : 2022/4/21 5:30 PM
# @Author : JianingWang
# @File : fusion_siamese.py
from typing import Optional
import torch
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.activations import ACT2FN
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.modeling_outputs import SequenceClassifierOutput
from loss.focal_loss import FocalLoss
# from roformer import RoFormerPreTrainedModel, RoFormerModel
class BertPooler(nn.Module):
def __init__(self, hidden_size, hidden_act):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
# self.activation = nn.Tanh()
self.activation = ACT2FN[hidden_act]
# self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, features):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
# x = self.dropout(x)
x = self.dense(x)
x = self.activation(x)
return x
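# Shape note for BertPooler: given `features` of shape [bz, seq_len, dim], it keeps only
# the first token's hidden state (the [CLS] position) and applies a dense layer followed
# by the activation named in config.hidden_act, i.e. roughly
#   pooled = ACT2FN[hidden_act](dense(features[:, 0, :]))   # [bz, dim]
# This mirrors the standard BERT pooler except that the activation is configurable
# instead of a fixed Tanh.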
class BertForFusionSiamese(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config)
self.hidden_size = config.hidden_size
self.hidden_act = config.hidden_act
self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
if hasattr(config, "cls_dropout_rate"):
cls_dropout_rate = config.cls_dropout_rate
else:
cls_dropout_rate = config.hidden_dropout_prob
self.dropout = nn.Dropout(cls_dropout_rate)
self.classifier = nn.Linear(3 * self.hidden_size, config.num_labels)
self.init_weights()
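    # Note: the classifier takes 3 * hidden_size features because forward() concatenates
    # the pooled [CLS] vector with the two mean-pooled segment vectors. As written, this
    # means segment_spans must be supplied at call time; without it cls_output stays at
    # hidden_size and the classifier's input dimension will not match.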
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
pseudo_label=None,
segment_spans=None,
pseuso_proba=None
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logits, outputs = None, None
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids,
"position_ids": position_ids,
"head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions,
"output_hidden_states": output_hidden_states, "return_dict": return_dict}
inputs = {k: v for k, v in inputs.items() if v is not None}
outputs = self.bert(**inputs)
if "sequence_output" in outputs:
sequence_output = outputs.sequence_output # [bz, seq_len, dim]
else:
sequence_output = outputs[0] # [bz, seq_len, dim]
cls_output = self.bert_poor(sequence_output) # [bz, dim]
if segment_spans is not None:
            # If the input provides two segments, mean-pool each segment separately.
            seg1_embeddings, seg2_embeddings = list(), list()
            for ei, sentence_embeddings in enumerate(sequence_output):
                # sentence_embeddings: [seq_len, dim]
                seg1_start, seg1_end, seg2_start, seg2_end = segment_spans[ei]
                seg1_embeddings.append(torch.mean(sentence_embeddings[seg1_start: seg1_end], 0))  # [dim]
                seg2_embeddings.append(torch.mean(sentence_embeddings[seg2_start: seg2_end], 0))  # [dim]
            seg1_embeddings, seg2_embeddings = torch.stack(seg1_embeddings), torch.stack(seg2_embeddings)  # [bz, dim]
            # dense_1 is shared by both segments (siamese weight sharing); dense_2 is defined but unused here.
            seg1_embeddings = self.bert_poor.activation(self.dense_1(seg1_embeddings))
            seg2_embeddings = self.bert_poor.activation(self.dense_1(seg2_embeddings))
cls_output = torch.cat([cls_output, seg1_embeddings, seg2_embeddings], dim=-1) # [bz, 3*dim]
# cls_output = cls_output + seg1_embeddings + seg2_embeddings # [bz, dim]
pooler_output = self.dropout(cls_output)
# pooler_output = self.LayerNorm(pooler_output)
logits = self.classifier(pooler_output)
loss = None
if labels is not None:
# loss_fct = FocalLoss()
loss_fct = CrossEntropyLoss()
            # Pseudo-labelling: examples flagged as real (pseudo_label > 0.9) and
            # pseudo-labelled examples (pseudo_label < 0.1) are weighted 0.9 / 0.1 in the loss.
if pseudo_label is not None:
train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
train_loss = loss_fct(train_logits.view(-1, self.num_labels),
train_labels.view(-1)) if train_labels.nelement() else 0
pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels),
pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
loss = 0.9 * train_loss + 0.1 * pseudo_loss
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
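# Usage sketch for BertForFusionSiamese (illustrative only; the checkpoint name and the
# way segment_spans is built below are assumptions, not code from this repository):
#
#   model = BertForFusionSiamese.from_pretrained("bert-base-chinese", num_labels=2)
#   enc = tokenizer(text_a, text_b, return_tensors="pt")
#   # segment_spans[i] = (seg1_start, seg1_end, seg2_start, seg2_end): token indices of
#   # the two segments to be mean-pooled in the concatenated sequence.
#   out = model(**enc, segment_spans=segment_spans, labels=labels)
#   out.loss, out.logits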
class BertForWSC(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config)
self.hidden_size = config.hidden_size
self.hidden_act = config.hidden_act
self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
if hasattr(config, "cls_dropout_rate"):
cls_dropout_rate = config.cls_dropout_rate
else:
cls_dropout_rate = config.hidden_dropout_prob
self.dropout = nn.Dropout(cls_dropout_rate)
self.classifier = nn.Linear(2 * self.hidden_size, config.num_labels)
self.init_weights()
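    # Note: unlike BertForFusionSiamese, the classifier here takes 2 * hidden_size
    # features; only the two mean-pooled span vectors are concatenated, and the pooled
    # [CLS] vector is left out (see the commented-out line in forward()).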
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
pseudo_label=None,
span=None,
pseuso_proba=None
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logits, outputs = None, None
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids,
"position_ids": position_ids,
"head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions,
"output_hidden_states": output_hidden_states, "return_dict": return_dict}
inputs = {k: v for k, v in inputs.items() if v is not None}
outputs = self.bert(**inputs)
if "sequence_output" in outputs:
sequence_output = outputs.sequence_output # [bz, seq_len, dim]
else:
sequence_output = outputs[0] # [bz, seq_len, dim]
# cls_output = self.bert_poor(sequence_output) # [bz, dim]
        # The input provides two spans; mean-pool each span separately.
        seg1_embeddings, seg2_embeddings = list(), list()
        for ei, sentence_embeddings in enumerate(sequence_output):
            # sentence_embeddings: [seq_len, dim]
            seg1_start, seg1_end, seg2_start, seg2_end = span[ei]
            # The token at each recorded start index is skipped (start + 1), so only the
            # interior of the span is averaged.
            seg1_embeddings.append(torch.mean(sentence_embeddings[seg1_start+1: seg1_end], 0))  # [dim]
            seg2_embeddings.append(torch.mean(sentence_embeddings[seg2_start+1: seg2_end], 0))  # [dim]
        seg1_embeddings, seg2_embeddings = torch.stack(seg1_embeddings), torch.stack(seg2_embeddings)  # [bz, dim]
        # seg1_embeddings = self.bert_poor.activation(self.dense_1(seg1_embeddings))
        # seg2_embeddings = self.bert_poor.activation(self.dense_1(seg2_embeddings))
        cls_output = torch.cat([seg1_embeddings, seg2_embeddings], dim=-1)  # [bz, 2*dim]
# cls_output = cls_output + seg1_embeddings + seg2_embeddings # [bz, dim]
pooler_output = self.dropout(cls_output)
# pooler_output = self.LayerNorm(pooler_output)
logits = self.classifier(pooler_output)
loss = None
if labels is not None:
# loss_fct = FocalLoss()
loss_fct = CrossEntropyLoss()
            # Pseudo-labelling: examples flagged as real (pseudo_label > 0.9) and
            # pseudo-labelled examples (pseudo_label < 0.1) are weighted 0.9 / 0.1 in the loss.
if pseudo_label is not None:
train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
train_loss = loss_fct(train_logits.view(-1, self.num_labels),
train_labels.view(-1)) if train_labels.nelement() else 0
pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels),
pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
loss = 0.9 * train_loss + 0.1 * pseudo_loss
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
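
# Minimal smoke test, kept behind the __main__ guard so importing this module has no side
# effects. It is a sketch only: the tiny BertConfig sizes and the dummy spans below are
# assumptions chosen to exercise the forward passes, not settings from this repository.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=2,
    )
    input_ids = torch.randint(0, 100, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 1])

    # BertForFusionSiamese: segment_spans[i] = (seg1_start, seg1_end, seg2_start, seg2_end).
    model = BertForFusionSiamese(config)
    segment_spans = torch.tensor([[1, 5, 6, 10], [2, 7, 8, 14]])
    out = model(input_ids=input_ids, attention_mask=attention_mask,
                segment_spans=segment_spans, labels=labels)
    print("fusion siamese:", out.logits.shape, out.loss.item())

    # BertForWSC: each span must contain at least two positions, because the token at the
    # recorded start index is skipped before mean-pooling.
    wsc = BertForWSC(config)
    span = torch.tensor([[1, 5, 6, 10], [2, 7, 8, 14]])
    out = wsc(input_ids=input_ids, attention_mask=attention_mask, span=span, labels=labels)
    print("wsc:", out.logits.shape, out.loss.item())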