import logging
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

from .configuration_gector import GectorConfig

logger = logging.getLogger(__name__)

GECTOR_PRETRAINED_BASE_MODEL_ARCHIVE_LIST = [
    "bert-base-cased",
    "bert-large-cased",
    "roberta-base",
    "roberta-large",
    "xlnet-base-cased",
    "xlnet-large-cased",
    "deberta-base-cased",
    "deberta-large-cased",
]


@dataclass
class GectorTokenClassifierOutput(TokenClassifierOutput):
    """Output of `GectorForTokenClassification`, extending `TokenClassifierOutput`
    with separate detection and correction heads."""

    loss: Optional[torch.FloatTensor] = None
    logits_detect: torch.FloatTensor = None
    class_probabilities_detect: torch.FloatTensor = None
    logits_correct: torch.FloatTensor = None
    class_probabilities_correct: torch.FloatTensor = None
    max_error_probabilities: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class GectorModel(PreTrainedModel):
    """Thin wrapper that instantiates the underlying transformer encoder."""

    config_class = GectorConfig

    def __init__(self, config):
        super().__init__(config)
        special_tokens_fix = config.special_tokens_fix
        encoder_config = AutoConfig.from_pretrained(config.model_id)
        self.encoder_model = AutoModel.from_config(encoder_config)
        if special_tokens_fix:
            # The GECToR preprocessing adds one extra special token to the vocabulary,
            # so the embedding matrix is resized to match.
            self.encoder_model.resize_token_embeddings(encoder_config.vocab_size + 1)

    def forward(self, *args, **kwargs):
        return self.encoder_model.forward(*args, **kwargs)


class GectorForTokenClassification(PreTrainedModel):
    config_class = GectorConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_detect_tags = config.num_detect_tags
        self.num_correct_tags = config.num_correct_tags

        self.text_field_embedder = GectorModel(config)
        self.embedding_size = self.text_field_embedder.encoder_model.config.hidden_size
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.detect_proj_layer = nn.Linear(self.embedding_size, self.num_detect_tags)
        self.correct_proj_layer = nn.Linear(self.embedding_size, self.num_correct_tags)

        self.delete_confidence = config.delete_confidence
        self.additional_confidence = config.additional_confidence
        self.incorrect_index = config.detect_label2id.get("$INCORRECT")

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        word_offsets: Optional[torch.LongTensor] = None,
        word_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], GectorTokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Detection and correction labels stacked along the last dimension and split
            internally with `torch.tensor_split(labels, 2, dim=-1)`. Detection indices
            should be in `[0, ..., config.num_detect_tags - 1]` and correction indices
            in `[0, ..., config.num_correct_tags - 1]`; positions equal to
            `config.detect_pad_token_id` / `config.correct_pad_token_id` are ignored
            in the loss.
""" return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.text_field_embedder( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] # If offsets are provided, the returned tensor will contain only the wordpiece # embeddings at those positions, and (in particular) will contain one embedding # per token. If offsets are not provided, the entire tensor of wordpiece embeddings # will be returned. if word_offsets is not None: indices = word_offsets.unsqueeze(-1).expand( -1, -1, sequence_output.size(-1) ) sequence_output = torch.gather(sequence_output, 1, indices) batch_size, sequence_length = sequence_output.size()[0:2] logits_detect = self.detect_proj_layer(sequence_output) logits_correct = self.correct_proj_layer(self.dropout(sequence_output)) class_probabilities_correct = nn.functional.softmax( logits_correct, dim=-1 ).view([batch_size, sequence_length, self.num_correct_tags]) class_probabilities_detect = nn.functional.softmax(logits_detect, dim=-1).view( [batch_size, sequence_length, self.num_detect_tags] ) max_error_probabilities = torch.max( class_probabilities_detect[:, :, self.incorrect_index] * word_mask, dim=-1, )[0] probability_change = [self.additional_confidence, self.delete_confidence] + [ 0 ] * (self.num_correct_tags - 2) class_probabilities_correct += ( torch.FloatTensor(probability_change) .repeat((batch_size, sequence_length, 1)) .to(self.device) ) loss = None if labels is not None: detect_labels, correct_labels = torch.tensor_split(labels, 2, dim=-1) # -100 is the default ignore_idx of CrossEntropyLoss detect_labels[detect_labels == self.config.detect_pad_token_id] = -100 correct_labels[correct_labels == self.config.correct_pad_token_id] = -100 detect_loss_fct = nn.CrossEntropyLoss() loss_detect = detect_loss_fct( logits_detect.view(-1, self.config.num_detect_tags), detect_labels.view(-1), ) correct_loss_fct = nn.CrossEntropyLoss( label_smoothing=self.config.label_smoothing ) loss_correct = correct_loss_fct( logits_correct.view(-1, self.config.num_correct_tags), correct_labels.view(-1), ) loss = loss_detect + loss_correct if not return_dict: output = (logits_detect, logits_correct) + outputs[2:] return ((loss,) + output) if loss is not None else output return GectorTokenClassifierOutput( loss=loss, logits_detect=logits_detect, class_probabilities_detect=class_probabilities_detect, logits_correct=logits_correct, class_probabilities_correct=class_probabilities_correct, max_error_probabilities=max_error_probabilities, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )