# IE101TW / models/span_extraction/span_for_ner.py
# DeepLearning101 — "Upload 2 files" (commit d131d1a)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel
from transformers.models.albert.modeling_albert import AlbertPreTrainedModel, AlbertModel
from transformers.models.megatron_bert.modeling_megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
from models.basic_modules.linears import PoolerEndLogits, PoolerStartLogits
from torch.nn import CrossEntropyLoss
from loss.focal_loss import FocalLoss
from loss.label_smoothing import LabelSmoothingCrossEntropy
class BertSpanForNer(BertPreTrainedModel):
    """BERT encoder with a start/end token-classification head for NER.

    ``start_fc`` scores a start label for every token. ``end_fc`` scores an end
    label from the encoder output together with a start-label signal.

    - When ``config.soft_label`` is true, that signal has ``num_labels``
      features per token (one-hot gold starts while training, softmax
      probabilities otherwise).
    - When it is false, the signal is a single float feature (the start
      label id).

    Config attributes read: ``soft_label``, ``num_labels``, ``loss_type``
    (one of "lsr", "focal", "ce"), ``hidden_size``, ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        # end_fc's input width is hidden_size plus the width of the start-label
        # signal built in _start_label_features (num_labels or 1 feature).
        start_signal_width = self.num_labels if self.soft_label else 1
        self.end_fc = PoolerEndLogits(config.hidden_size + start_signal_width, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        """Encode the batch, score span starts/ends, and optionally compute the loss.

        Args:
            input_ids: token ids; dim 0 is the batch and dim 1 the sequence.
            token_type_ids: optional segment ids, forwarded to the encoder.
            attention_mask: optional mask forwarded to the encoder. Positions
                equal to 1 contribute to the loss. If None, all positions do.
            start_positions: optional per-token gold start label ids.
                In training mode they are used to condition ``end_fc``.
            end_positions: optional per-token gold end label ids.

        Returns:
            A tuple ``(start_logits, end_logits, *encoder_outputs[2:])``.
            When both ``start_positions`` and ``end_positions`` are given, it
            is prefixed with the mean of the start and end losses.
        """
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = self.dropout(encoder_outputs[0])
        start_logits = self.start_fc(sequence_output)
        label_logits = self._start_label_features(input_ids, start_logits, start_positions)
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + encoder_outputs[2:]
        if start_positions is not None and end_positions is not None:
            total_loss = self._span_loss(start_logits, end_logits, start_positions, end_positions, attention_mask)
            outputs = (total_loss,) + outputs
        return outputs

    def _start_label_features(self, input_ids, start_logits, start_positions):
        """Build the start-label signal passed to ``end_fc``.

        In training mode with gold starts available, the gold labels are used
        (teacher forcing). Otherwise the model's own start scores are used.
        """
        if start_positions is not None and self.training:
            if self.soft_label:
                # One-hot encode the gold start labels. Allocate directly on the
                # target device instead of building on the CPU and copying.
                label_logits = torch.zeros(
                    input_ids.size(0), input_ids.size(1), self.num_labels,
                    dtype=torch.float, device=input_ids.device,
                )
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
                return label_logits
            return start_positions.unsqueeze(2).float()
        label_logits = F.softmax(start_logits, -1)
        if not self.soft_label:
            # Hard mode: feed only the predicted start label id.
            label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        return label_logits

    def _span_loss(self, start_logits, end_logits, start_positions, end_positions, attention_mask):
        """Return the mean of the start and end losses over active positions."""
        assert self.loss_type in ["lsr", "focal", "ce"]
        if self.loss_type == "lsr":
            loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == "focal":
            loss_fct = FocalLoss()
        else:
            loss_fct = CrossEntropyLoss()
        if attention_mask is None:
            # No mask supplied: every position is active (the encoder makes the
            # same assumption when attention_mask is None).
            active = torch.ones_like(start_positions, dtype=torch.bool).reshape(-1)
        else:
            active = attention_mask.reshape(-1) == 1
        active_start_logits = start_logits.reshape(-1, self.num_labels)[active]
        active_end_logits = end_logits.reshape(-1, self.num_labels)[active]
        active_start_labels = start_positions.reshape(-1)[active]
        active_end_labels = end_positions.reshape(-1)[active]
        start_loss = loss_fct(active_start_logits, active_start_labels)
        end_loss = loss_fct(active_end_logits, active_end_labels)
        return (start_loss + end_loss) / 2
class RobertaSpanForNer(RobertaPreTrainedModel):
    """RoBERTa encoder with a start/end token-classification head for NER.

    ``start_fc`` scores a start label for every token. ``end_fc`` scores an end
    label from the encoder output together with a start-label signal.

    - When ``config.soft_label`` is true, that signal has ``num_labels``
      features per token (one-hot gold starts while training, softmax
      probabilities otherwise).
    - When it is false, the signal is a single float feature (the start
      label id).

    Config attributes read: ``soft_label``, ``num_labels``, ``loss_type``
    (one of "lsr", "focal", "ce"), ``hidden_size``, ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        # end_fc's input width is hidden_size plus the width of the start-label
        # signal built in _start_label_features (num_labels or 1 feature).
        start_signal_width = self.num_labels if self.soft_label else 1
        self.end_fc = PoolerEndLogits(config.hidden_size + start_signal_width, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        """Encode the batch, score span starts/ends, and optionally compute the loss.

        Args:
            input_ids: token ids; dim 0 is the batch and dim 1 the sequence.
            token_type_ids: optional segment ids, forwarded to the encoder.
            attention_mask: optional mask forwarded to the encoder. Positions
                equal to 1 contribute to the loss. If None, all positions do.
            start_positions: optional per-token gold start label ids.
                In training mode they are used to condition ``end_fc``.
            end_positions: optional per-token gold end label ids.

        Returns:
            A tuple ``(start_logits, end_logits, *encoder_outputs[2:])``.
            When both ``start_positions`` and ``end_positions`` are given, it
            is prefixed with the mean of the start and end losses.
        """
        encoder_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = self.dropout(encoder_outputs[0])
        start_logits = self.start_fc(sequence_output)
        label_logits = self._start_label_features(input_ids, start_logits, start_positions)
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + encoder_outputs[2:]
        if start_positions is not None and end_positions is not None:
            total_loss = self._span_loss(start_logits, end_logits, start_positions, end_positions, attention_mask)
            outputs = (total_loss,) + outputs
        return outputs

    def _start_label_features(self, input_ids, start_logits, start_positions):
        """Build the start-label signal passed to ``end_fc``.

        In training mode with gold starts available, the gold labels are used
        (teacher forcing). Otherwise the model's own start scores are used.
        """
        if start_positions is not None and self.training:
            if self.soft_label:
                # One-hot encode the gold start labels. Allocate directly on the
                # target device instead of building on the CPU and copying.
                label_logits = torch.zeros(
                    input_ids.size(0), input_ids.size(1), self.num_labels,
                    dtype=torch.float, device=input_ids.device,
                )
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
                return label_logits
            return start_positions.unsqueeze(2).float()
        label_logits = F.softmax(start_logits, -1)
        if not self.soft_label:
            # Hard mode: feed only the predicted start label id.
            label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        return label_logits

    def _span_loss(self, start_logits, end_logits, start_positions, end_positions, attention_mask):
        """Return the mean of the start and end losses over active positions."""
        assert self.loss_type in ["lsr", "focal", "ce"]
        if self.loss_type == "lsr":
            loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == "focal":
            loss_fct = FocalLoss()
        else:
            loss_fct = CrossEntropyLoss()
        if attention_mask is None:
            # No mask supplied: every position is active (the encoder makes the
            # same assumption when attention_mask is None).
            active = torch.ones_like(start_positions, dtype=torch.bool).reshape(-1)
        else:
            active = attention_mask.reshape(-1) == 1
        active_start_logits = start_logits.reshape(-1, self.num_labels)[active]
        active_end_logits = end_logits.reshape(-1, self.num_labels)[active]
        active_start_labels = start_positions.reshape(-1)[active]
        active_end_labels = end_positions.reshape(-1)[active]
        start_loss = loss_fct(active_start_logits, active_start_labels)
        end_loss = loss_fct(active_end_logits, active_end_labels)
        return (start_loss + end_loss) / 2
class AlbertSpanForNer(AlbertPreTrainedModel):
    """ALBERT encoder with a start/end token-classification head for NER.

    ``start_fc`` scores a start label for every token. ``end_fc`` scores an end
    label from the encoder output together with a start-label signal.

    - When ``config.soft_label`` is true, that signal has ``num_labels``
      features per token (one-hot gold starts while training, softmax
      probabilities otherwise).
    - When it is false, the signal is a single float feature (the start
      label id).

    Config attributes read: ``soft_label``, ``num_labels``, ``loss_type``
    (one of "lsr", "focal", "ce"), ``hidden_size``, ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        # NOTE(review): the ALBERT encoder is stored as `self.bert`. transformers'
        # ALBERT classes appear to use base_model_prefix "albert", so
        # from_pretrained may not map pretrained encoder weights onto this
        # attribute — check the missing-keys log. The name is kept because
        # renaming it changes state_dict keys of already fine-tuned checkpoints.
        self.bert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        # end_fc's input width is hidden_size plus the width of the start-label
        # signal built in _start_label_features (num_labels or 1 feature).
        start_signal_width = self.num_labels if self.soft_label else 1
        self.end_fc = PoolerEndLogits(config.hidden_size + start_signal_width, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        """Encode the batch, score span starts/ends, and optionally compute the loss.

        Args:
            input_ids: token ids; dim 0 is the batch and dim 1 the sequence.
            token_type_ids: optional segment ids, forwarded to the encoder.
            attention_mask: optional mask forwarded to the encoder. Positions
                equal to 1 contribute to the loss. If None, all positions do.
            start_positions: optional per-token gold start label ids.
                In training mode they are used to condition ``end_fc``.
            end_positions: optional per-token gold end label ids.

        Returns:
            A tuple ``(start_logits, end_logits, *encoder_outputs[2:])``.
            When both ``start_positions`` and ``end_positions`` are given, it
            is prefixed with the mean of the start and end losses.
        """
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = self.dropout(encoder_outputs[0])
        start_logits = self.start_fc(sequence_output)
        label_logits = self._start_label_features(input_ids, start_logits, start_positions)
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + encoder_outputs[2:]
        if start_positions is not None and end_positions is not None:
            total_loss = self._span_loss(start_logits, end_logits, start_positions, end_positions, attention_mask)
            outputs = (total_loss,) + outputs
        return outputs

    def _start_label_features(self, input_ids, start_logits, start_positions):
        """Build the start-label signal passed to ``end_fc``.

        In training mode with gold starts available, the gold labels are used
        (teacher forcing). Otherwise the model's own start scores are used.
        """
        if start_positions is not None and self.training:
            if self.soft_label:
                # One-hot encode the gold start labels. Allocate directly on the
                # target device instead of building on the CPU and copying.
                label_logits = torch.zeros(
                    input_ids.size(0), input_ids.size(1), self.num_labels,
                    dtype=torch.float, device=input_ids.device,
                )
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
                return label_logits
            return start_positions.unsqueeze(2).float()
        label_logits = F.softmax(start_logits, -1)
        if not self.soft_label:
            # Hard mode: feed only the predicted start label id.
            label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        return label_logits

    def _span_loss(self, start_logits, end_logits, start_positions, end_positions, attention_mask):
        """Return the mean of the start and end losses over active positions."""
        assert self.loss_type in ["lsr", "focal", "ce"]
        if self.loss_type == "lsr":
            loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == "focal":
            loss_fct = FocalLoss()
        else:
            loss_fct = CrossEntropyLoss()
        if attention_mask is None:
            # No mask supplied: every position is active (the encoder makes the
            # same assumption when attention_mask is None).
            active = torch.ones_like(start_positions, dtype=torch.bool).reshape(-1)
        else:
            active = attention_mask.reshape(-1) == 1
        active_start_logits = start_logits.reshape(-1, self.num_labels)[active]
        active_end_logits = end_logits.reshape(-1, self.num_labels)[active]
        active_start_labels = start_positions.reshape(-1)[active]
        active_end_labels = end_positions.reshape(-1)[active]
        start_loss = loss_fct(active_start_logits, active_start_labels)
        end_loss = loss_fct(active_end_logits, active_end_labels)
        return (start_loss + end_loss) / 2
class MegatronBertSpanForNer(MegatronBertPreTrainedModel):
    """Megatron-BERT encoder with a start/end token-classification head for NER.

    ``start_fc`` scores a start label for every token. ``end_fc`` scores an end
    label from the encoder output together with a start-label signal. Unlike
    the sibling span models, ``soft_label`` and ``loss_type`` are fixed here
    (soft labels, cross-entropy) rather than read from the config.

    Config attributes read: ``num_labels``, ``hidden_size``,
    ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        # Fixed: this previously called super(BertSpanForNer, self), which raises
        # TypeError because self is not a BertSpanForNer instance.
        super().__init__(config)
        # Hard-coded instead of config.soft_label / config.loss_type, presumably
        # because the Megatron config used with this class lacks those
        # attributes — confirm before switching to config-driven values (the
        # end_fc input width depends on soft_label).
        self.soft_label = True
        self.num_labels = config.num_labels
        self.loss_type = "ce"
        self.bert = MegatronBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        # end_fc's input width is hidden_size plus the width of the start-label
        # signal built in _start_label_features (num_labels or 1 feature).
        start_signal_width = self.num_labels if self.soft_label else 1
        self.end_fc = PoolerEndLogits(config.hidden_size + start_signal_width, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        """Encode the batch, score span starts/ends, and optionally compute the loss.

        Args:
            input_ids: token ids; dim 0 is the batch and dim 1 the sequence.
            token_type_ids: optional segment ids, forwarded to the encoder.
            attention_mask: optional mask forwarded to the encoder. Positions
                equal to 1 contribute to the loss. If None, all positions do.
            start_positions: optional per-token gold start label ids.
                In training mode they are used to condition ``end_fc``.
            end_positions: optional per-token gold end label ids.

        Returns:
            A tuple ``(start_logits, end_logits, *encoder_outputs[2:])``.
            When both ``start_positions`` and ``end_positions`` are given, it
            is prefixed with the mean of the start and end losses.
        """
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = self.dropout(encoder_outputs[0])
        start_logits = self.start_fc(sequence_output)
        label_logits = self._start_label_features(input_ids, start_logits, start_positions)
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + encoder_outputs[2:]
        if start_positions is not None and end_positions is not None:
            total_loss = self._span_loss(start_logits, end_logits, start_positions, end_positions, attention_mask)
            outputs = (total_loss,) + outputs
        return outputs

    def _start_label_features(self, input_ids, start_logits, start_positions):
        """Build the start-label signal passed to ``end_fc``.

        In training mode with gold starts available, the gold labels are used
        (teacher forcing). Otherwise the model's own start scores are used.
        """
        if start_positions is not None and self.training:
            if self.soft_label:
                # One-hot encode the gold start labels. Allocate directly on the
                # target device instead of building on the CPU and copying.
                label_logits = torch.zeros(
                    input_ids.size(0), input_ids.size(1), self.num_labels,
                    dtype=torch.float, device=input_ids.device,
                )
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
                return label_logits
            return start_positions.unsqueeze(2).float()
        label_logits = F.softmax(start_logits, -1)
        if not self.soft_label:
            # Hard mode: feed only the predicted start label id.
            label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        return label_logits

    def _span_loss(self, start_logits, end_logits, start_positions, end_positions, attention_mask):
        """Return the mean of the start and end losses over active positions."""
        assert self.loss_type in ["lsr", "focal", "ce"]
        if self.loss_type == "lsr":
            loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == "focal":
            loss_fct = FocalLoss()
        else:
            loss_fct = CrossEntropyLoss()
        if attention_mask is None:
            # No mask supplied: every position is active (the encoder makes the
            # same assumption when attention_mask is None).
            active = torch.ones_like(start_positions, dtype=torch.bool).reshape(-1)
        else:
            active = attention_mask.reshape(-1) == 1
        active_start_logits = start_logits.reshape(-1, self.num_labels)[active]
        active_end_logits = end_logits.reshape(-1, self.num_labels)[active]
        active_start_labels = start_positions.reshape(-1)[active]
        active_end_labels = end_positions.reshape(-1)[active]
        start_loss = loss_fct(active_start_logits, active_start_labels)
        end_loss = loss_fct(active_end_logits, active_end_labels)
        return (start_loss + end_loss) / 2