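"""PyTorch GECToR model.

GECToR ("Grammatical Error Correction: Tag, Not Rewrite", Omelianchuk et al., 2020)
casts error correction as token classification: an error-detection head and a
correction-tagging head sit on top of a pretrained transformer encoder.
"""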
import logging
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import PreTrainedModel, AutoModel, AutoConfig
from transformers.modeling_outputs import TokenClassifierOutput

from .configuration_gector import GectorConfig

logger = logging.getLogger(__name__)

GECTOR_PRETRAINED_BASE_MODEL_ARCHIVE_LIST = [
    "bert-base-cased",
    "bert-large-cased",
    "roberta-base",
    "roberta-large",
    "xlnet-base-cased",
    "xlnet-large-cased",
    "deberta-base-cased",
    "deberta-large-cased",
]


@dataclass
class GectorTokenClassifierOutput(TokenClassifierOutput):
    """Output of GECToR token classification, with separate detection and correction heads."""

    loss: Optional[torch.FloatTensor] = None
    logits_detect: Optional[torch.FloatTensor] = None
    class_probabilities_detect: Optional[torch.FloatTensor] = None
    logits_correct: Optional[torch.FloatTensor] = None
    class_probabilities_correct: Optional[torch.FloatTensor] = None
    max_error_probabilities: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class GectorModel(PreTrainedModel):
    """Wraps the pretrained transformer encoder named in the GECToR config."""

    config_class = GectorConfig

    def __init__(self, config):
        super().__init__(config)
        special_tokens_fix = config.special_tokens_fix

        encoder_config = AutoConfig.from_pretrained(config.model_id)
        self.encoder_model = AutoModel.from_config(encoder_config)

        if special_tokens_fix:
            # Grow the embedding matrix by one to make room for the extra
            # special token that GECToR prepends to every sentence.
            self.encoder_model.resize_token_embeddings(encoder_config.vocab_size + 1)

    def forward(self, *args, **kwargs):
        # Call the module (not .forward) so that registered hooks are honored.
        return self.encoder_model(*args, **kwargs)


class GectorForTokenClassification(PreTrainedModel):
    """GECToR tagger: an error-detection head and a correction head over the encoder."""

    config_class = GectorConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_detect_tags = config.num_detect_tags
        self.num_correct_tags = config.num_correct_tags

        self.text_field_embedder = GectorModel(config)
        self.embedding_size = self.text_field_embedder.encoder_model.config.hidden_size

        self.dropout = nn.Dropout(config.classifier_dropout)

        self.detect_proj_layer = nn.Linear(self.embedding_size, self.num_detect_tags)
        self.correct_proj_layer = nn.Linear(self.embedding_size, self.num_correct_tags)

        # Inference-time confidence biases added to the first two correction
        # tags (see forward), and the index of the $INCORRECT detection tag.
        self.delete_confidence = config.delete_confidence
        self.additional_confidence = config.additional_confidence
        self.incorrect_index = config.detect_label2id.get("$INCORRECT")

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        word_offsets: Optional[torch.LongTensor] = None,
        word_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], GectorTokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Labels for computing the token classification loss, with the detection
            and correction tag ids stacked along the last dimension (they are
            separated with `torch.tensor_split`). Detection indices should be in
            `[0, ..., config.num_detect_tags - 1]` and correction indices in
            `[0, ..., config.num_correct_tags - 1]`; padding positions are ignored.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.text_field_embedder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Last hidden states of the encoder, shape (batch, subtoken_len, hidden).
        sequence_output = outputs[0]

        if word_offsets is not None:
            # Gather one encoder state per word (e.g. its first subtoken) so
            # the classifiers run over words rather than subtokens.
            indices = word_offsets.unsqueeze(-1).expand(
                -1, -1, sequence_output.size(-1)
            )
            sequence_output = torch.gather(sequence_output, 1, indices)
        batch_size, sequence_length = sequence_output.size()[0:2]

        logits_detect = self.detect_proj_layer(sequence_output)
        logits_correct = self.correct_proj_layer(self.dropout(sequence_output))

        class_probabilities_correct = nn.functional.softmax(
            logits_correct, dim=-1
        ).view([batch_size, sequence_length, self.num_correct_tags])
        class_probabilities_detect = nn.functional.softmax(logits_detect, dim=-1).view(
            [batch_size, sequence_length, self.num_detect_tags]
        )

        # Per-sentence error probability: the highest $INCORRECT probability
        # over (masked) positions.
        error_probs = class_probabilities_detect[:, :, self.incorrect_index]
        if word_mask is not None:
            error_probs = error_probs * word_mask
        max_error_probabilities = torch.max(error_probs, dim=-1)[0]

        # Inference-time confidence bias: nudge the probabilities of the first
        # two correction tags (typically $KEEP and $DELETE); all others get 0.
        probability_change = [self.additional_confidence, self.delete_confidence] + [
            0.0
        ] * (self.num_correct_tags - 2)
        class_probabilities_correct = class_probabilities_correct + torch.tensor(
            probability_change,
            dtype=class_probabilities_correct.dtype,
            device=class_probabilities_correct.device,
        )

        loss = None
        if labels is not None:
            # `labels` stacks the detection and correction tag ids along the
            # last dimension; clone the splits before masking, since
            # tensor_split returns views into the caller's tensor.
            detect_labels, correct_labels = torch.tensor_split(labels, 2, dim=-1)
            detect_labels = detect_labels.clone()
            correct_labels = correct_labels.clone()

            # Map padding positions to -100, the ignore_index of CrossEntropyLoss.
            detect_labels[detect_labels == self.config.detect_pad_token_id] = -100
            correct_labels[correct_labels == self.config.correct_pad_token_id] = -100

            detect_loss_fct = nn.CrossEntropyLoss()
            loss_detect = detect_loss_fct(
                logits_detect.view(-1, self.config.num_detect_tags),
                # reshape (not view): the split labels may be non-contiguous.
                detect_labels.reshape(-1),
            )

            correct_loss_fct = nn.CrossEntropyLoss(
                label_smoothing=self.config.label_smoothing
            )
            loss_correct = correct_loss_fct(
                logits_correct.view(-1, self.config.num_correct_tags),
                correct_labels.reshape(-1),
            )
            loss = loss_detect + loss_correct

        if not return_dict:
            output = (logits_detect, logits_correct) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return GectorTokenClassifierOutput(
            loss=loss,
            logits_detect=logits_detect,
            class_probabilities_detect=class_probabilities_detect,
            logits_correct=logits_correct,
            class_probabilities_correct=class_probabilities_correct,
            max_error_probabilities=max_error_probabilities,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
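
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a trained GECToR
# checkpoint saved with this `GectorForTokenClassification` class; the
# checkpoint path and example sentence below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Hypothetical local checkpoint directory.
    model = GectorForTokenClassification.from_pretrained("path/to/gector-checkpoint")
    tokenizer = AutoTokenizer.from_pretrained(model.config.model_id)
    model.eval()

    encoding = tokenizer("She go to school yesterday .", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)

    # Most likely correction tag per position, plus the per-sentence
    # probability that the sentence contains an error.
    predicted_tags = outputs.class_probabilities_correct.argmax(dim=-1)
    print(predicted_tags, outputs.max_error_probabilities)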