# modelling_gector.py
import torch
import logging
from torch import nn
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModel, AutoConfig
from transformers.modeling_outputs import TokenClassifierOutput
from .configuration_gector import GectorConfig
logger = logging.getLogger(__name__)
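# Backbone encoders that GECToR checkpoints are built on top of.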
GECTOR_PRETRAINED_BASE_MODEL_ARCHIVE_LIST = [
"bert-base-cased",
"bert-large-cased",
"roberta-base",
"roberta-large",
"xlnet-base-cased",
"xlnet-large-cased",
"deberta-base-cased",
"deberta-large-cased",
]
@dataclass
class GectorTokenClassifierOutput(TokenClassifierOutput):
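    """
    Output of `GectorForTokenClassification`.

    Extends `TokenClassifierOutput` with the logits and softmax probabilities of the
    two GECToR heads (error detection and correction tagging) and the per-sentence
    maximum error probability taken over the detection head's $INCORRECT class.
    """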
loss: Optional[torch.FloatTensor] = None
    logits_detect: Optional[torch.FloatTensor] = None
    class_probabilities_detect: Optional[torch.FloatTensor] = None
    logits_correct: Optional[torch.FloatTensor] = None
    class_probabilities_correct: Optional[torch.FloatTensor] = None
    max_error_probabilities: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class GectorModel(PreTrainedModel):
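    """Wrapper that builds the underlying transformer encoder from `config.model_id`."""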
config_class = GectorConfig
    def __init__(self, config):
        super().__init__(config)
        # Instantiate the backbone encoder from the configuration of `config.model_id`
        # (architecture only; weights are filled in when the GECToR checkpoint is loaded).
        encoder_config = AutoConfig.from_pretrained(config.model_id)
        self.encoder_model = AutoModel.from_config(encoder_config)
        if config.special_tokens_fix:
            # `special_tokens_fix` accounts for one extra special token in the GECToR
            # tokenizer, so the embedding matrix is grown by one row to match.
            self.encoder_model.resize_token_embeddings(encoder_config.vocab_size + 1)
def forward(self, *args, **kwargs):
return self.encoder_model.forward(*args, **kwargs)
class GectorForTokenClassification(PreTrainedModel):
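    """GECToR tagger with two token-level heads: error detection and correction-tag prediction."""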
config_class = GectorConfig
def __init__(self, config):
super().__init__(config)
self.num_detect_tags = config.num_detect_tags
self.num_correct_tags = config.num_correct_tags
self.text_field_embedder = GectorModel(config)
self.embedding_size = self.text_field_embedder.encoder_model.config.hidden_size
self.dropout = nn.Dropout(config.classifier_dropout)
self.detect_proj_layer = nn.Linear(self.embedding_size, self.num_detect_tags)
self.correct_proj_layer = nn.Linear(self.embedding_size, self.num_correct_tags)
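        # Inference-time confidence biases applied to the first two correction labels
        # (by GECToR convention $KEEP and $DELETE), and the index of the $INCORRECT
        # detection label used for the sentence-level error probability.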
self.delete_confidence = config.delete_confidence
self.additional_confidence = config.additional_confidence
self.incorrect_index = config.detect_label2id.get("$INCORRECT")
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
word_offsets: Optional[torch.LongTensor] = None,
word_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], GectorTokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.text_field_embedder(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
# If offsets are provided, the returned tensor will contain only the wordpiece
# embeddings at those positions, and (in particular) will contain one embedding
# per token. If offsets are not provided, the entire tensor of wordpiece embeddings
# will be returned.
if word_offsets is not None:
indices = word_offsets.unsqueeze(-1).expand(
-1, -1, sequence_output.size(-1)
)
sequence_output = torch.gather(sequence_output, 1, indices)
batch_size, sequence_length = sequence_output.size()[0:2]
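        # The detection head sees the raw encoder output; dropout is applied only
        # before the correction head.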
logits_detect = self.detect_proj_layer(sequence_output)
logits_correct = self.correct_proj_layer(self.dropout(sequence_output))
class_probabilities_correct = nn.functional.softmax(
logits_correct, dim=-1
).view([batch_size, sequence_length, self.num_correct_tags])
class_probabilities_detect = nn.functional.softmax(logits_detect, dim=-1).view(
[batch_size, sequence_length, self.num_detect_tags]
)
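        # Sentence-level error score: the highest $INCORRECT probability over the
        # non-padding positions selected by `word_mask` (required here).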
max_error_probabilities = torch.max(
class_probabilities_detect[:, :, self.incorrect_index] * word_mask,
dim=-1,
)[0]
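        # GECToR confidence trick: add fixed biases to the probabilities of the first
        # two correction labels before they are returned for decoding.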
probability_change = [self.additional_confidence, self.delete_confidence] + [
0
] * (self.num_correct_tags - 2)
class_probabilities_correct += (
torch.FloatTensor(probability_change)
.repeat((batch_size, sequence_length, 1))
.to(self.device)
)
loss = None
        if labels is not None:
            # Detection and correction labels arrive stacked along the last dimension.
            detect_labels, correct_labels = torch.tensor_split(labels, 2, dim=-1)
            # Clone so the caller's `labels` is untouched and `.view(-1)` below gets contiguous tensors.
            detect_labels = detect_labels.clone()
            correct_labels = correct_labels.clone()
            # -100 is the default ignore_index of CrossEntropyLoss
            detect_labels[detect_labels == self.config.detect_pad_token_id] = -100
            correct_labels[correct_labels == self.config.correct_pad_token_id] = -100
detect_loss_fct = nn.CrossEntropyLoss()
loss_detect = detect_loss_fct(
logits_detect.view(-1, self.config.num_detect_tags),
detect_labels.view(-1),
)
correct_loss_fct = nn.CrossEntropyLoss(
label_smoothing=self.config.label_smoothing
)
loss_correct = correct_loss_fct(
logits_correct.view(-1, self.config.num_correct_tags),
correct_labels.view(-1),
)
loss = loss_detect + loss_correct
if not return_dict:
output = (logits_detect, logits_correct) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return GectorTokenClassifierOutput(
loss=loss,
logits_detect=logits_detect,
class_probabilities_detect=class_probabilities_detect,
logits_correct=logits_correct,
class_probabilities_correct=class_probabilities_correct,
max_error_probabilities=max_error_probabilities,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
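# ---------------------------------------------------------------------------
# Minimal usage sketch (comments only, illustrative; the checkpoint path placeholder
# and the shapes of the preprocessed tensors are assumptions, not part of this module):
#
#     model = GectorForTokenClassification.from_pretrained("<path-to-gector-checkpoint>").eval()
#     with torch.no_grad():
#         outputs = model(
#             input_ids=input_ids,            # (batch, num_wordpieces)
#             attention_mask=attention_mask,  # (batch, num_wordpieces)
#             word_offsets=word_offsets,      # (batch, num_words): wordpiece position per word
#             word_mask=word_mask,            # (batch, num_words): 1 for real words, 0 for padding
#         )
#     outputs.max_error_probabilities        # (batch,) sentence-level error scores
#     outputs.class_probabilities_correct.argmax(-1)  # (batch, num_words) predicted correction tags
# ---------------------------------------------------------------------------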