from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, BertPreTrainedModel
from transformers.modeling_outputs import ModelOutput


def get_range_vector(size: int, device: torch.device) -> torch.Tensor:
    """
    Returns a range vector with the desired size, starting at 0.
    Creating the tensor directly on the target device avoids copying data from CPU to GPU.
    """
    return torch.arange(0, size, dtype=torch.long, device=device)


@dataclass
class Seq2LabelsOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    detect_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    max_error_probability: Optional[torch.FloatTensor] = None


class Seq2LabelsModel(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_detect_classes = config.num_detect_classes
        self.label_smoothing = config.label_smoothing

        if config.load_pretrained:
            self.bert = AutoModel.from_pretrained(config.pretrained_name_or_path)
            bert_config = self.bert.config
        else:
            bert_config = AutoConfig.from_pretrained(config.pretrained_name_or_path)
            self.bert = AutoModel.from_config(bert_config)

        if config.special_tokens_fix:
            try:
                vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # reserve more space
                vocab_size = self.bert.word_embedding.num_embeddings + 5
            self.bert.resize_token_embeddings(vocab_size + 1)

        predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
        self.dropout = nn.Dropout(predictor_dropout)
        self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
        self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_offsets: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        d_tags: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if input_offsets is not None:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
            # selected embeddings are also (batch_size, d1, ..., dn, orig_sequence_length)
            sequence_output = sequence_output[range_vector, input_offsets]

        logits = self.classifier(self.dropout(sequence_output))
        logits_d = self.detector(sequence_output)

        loss = None
        if labels is not None and d_tags is not None:
            loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
            loss_d_fct = CrossEntropyLoss()
            loss_labels = loss_labels_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss_d = loss_d_fct(logits_d.view(-1, self.num_detect_classes), d_tags.view(-1))
            loss = loss_labels + loss_d

        if not return_dict:
            output = (logits, logits_d) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Seq2LabelsOutput(
            loss=loss,
            logits=logits,
            detect_logits=logits_d,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            max_error_probability=torch.ones(logits.size(0), device=logits.device),
        )
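

# Minimal usage sketch (illustrative only, not part of the model definition). It
# assumes the model's config exposes the attributes read above (load_pretrained,
# pretrained_name_or_path, special_tokens_fix, predictor_dropout, vocab_size,
# num_detect_classes, label_smoothing); here a plain BertConfig with extra kwargs
# stands in for that config class, and all hyperparameter values are made up.
if __name__ == "__main__":
    from transformers import AutoTokenizer, BertConfig

    config = BertConfig(
        num_labels=5000,                 # assumed size of the label (edit-tag) vocabulary
        vocab_size=5000,                 # classifier output size; matches num_labels in this sketch
        num_detect_classes=4,            # assumed number of detection tags
        label_smoothing=0.0,
        load_pretrained=True,
        pretrained_name_or_path="bert-base-cased",
        special_tokens_fix=False,
        predictor_dropout=0.0,
    )
    model = Seq2LabelsModel(config)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    encoded = tokenizer("A simple example sentence .", return_tensors="pt")
    outputs = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
    )
    # One label distribution and one detection distribution per sub-word position.
    print(outputs.logits.shape, outputs.detect_logits.shape)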