from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModel, BertPreTrainedModel
from transformers.modeling_outputs import ModelOutput


def get_range_vector(size: int, device: torch.device) -> torch.Tensor:
    """
    Returns a range vector with the desired size, starting at 0. The tensor is created
    directly on the target device so that no data has to be copied from the CPU to the GPU.
    """
    return torch.arange(0, size, dtype=torch.long, device=device)
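

# Illustrative example (not part of the original module): the helper above is used for
# row-wise advanced indexing in Seq2LabelsModel.forward(). A (batch_size, 1) column of
# row indices broadcast against a (batch_size, num_words) offset matrix picks one hidden
# state per word for every sentence in the batch:
#
#   rows = get_range_vector(2, torch.device("cpu")).unsqueeze(1)   # shape (2, 1)
#   offsets = torch.tensor([[1, 3], [2, 5]])                       # shape (2, 2)
#   word_states = hidden_states[rows, offsets]                     # shape (2, 2, hidden)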


@dataclass
class Seq2LabelsOutput(ModelOutput):
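    """
    Output type of `Seq2LabelsModel`.

    `logits` holds the per-token scores of the correction-label head and `detect_logits`
    the per-token scores of the error-detection head; `max_error_probability` is a
    per-sentence score, which `forward()` currently fills with ones. The remaining
    fields follow the usual `ModelOutput` conventions.
    """
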
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    detect_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    max_error_probability: Optional[torch.FloatTensor] = None


class Seq2LabelsModel(BertPreTrainedModel):
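    """
    Sequence-tagging model for error correction: a pretrained encoder with two
    token-level linear heads, `classifier` for correction labels and `detector`
    for error-detection tags.
    """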

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_detect_classes = config.num_detect_classes
        self.label_smoothing = config.label_smoothing

        if config.load_pretrained:
            self.bert = AutoModel.from_pretrained(config.pretrained_name_or_path)
            bert_config = self.bert.config
        else:
            bert_config = AutoConfig.from_pretrained(config.pretrained_name_or_path)
            self.bert = AutoModel.from_config(bert_config)

        if config.special_tokens_fix:
            try:
                vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
            except AttributeError:
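                # Fallback for encoders (e.g. XLNet-style models) that expose their input
                # embeddings as `word_embedding` instead of `embeddings.word_embeddings`;
                # the fixed offset below presumably accounts for such encoders' extra
                # special tokens.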
                vocab_size = self.bert.word_embedding.num_embeddings + 5
            self.bert.resize_token_embeddings(vocab_size + 1)

        predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
        self.dropout = nn.Dropout(predictor_dropout)
        # Correction-label head; config.vocab_size is the size of the label vocabulary
        # (it is expected to match config.num_labels, which forward() uses for the loss).
        self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
        # Error-detection head.
        self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_offsets: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        d_tags: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
        r"""
        input_offsets (`torch.LongTensor` of shape `(batch_size, num_words)`, *optional*):
            Positions of the sub-tokens (typically the first sub-token of each word) whose
            hidden states are fed to the classification heads.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.
        d_tags (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for the error-detection head. Indices should be in
            `[0, ..., config.num_detect_classes - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if input_offsets is not None:
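            # Gather word-level representations: a (batch_size, 1) column of row indices
            # is broadcast against the (batch_size, num_words) offset matrix so that, for
            # every sentence, only the hidden states at the given sub-token positions are
            # kept, giving a (batch_size, num_words, hidden_size) tensor.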
            range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
            sequence_output = sequence_output[range_vector, input_offsets]

        logits = self.classifier(self.dropout(sequence_output))
        logits_d = self.detector(sequence_output)

        loss = None
        if labels is not None and d_tags is not None:
            # Sum of the two token-level losses; positions whose label equals the default
            # ignore_index (-100) are skipped by CrossEntropyLoss. The reshape assumes
            # that self.num_labels matches the classifier's output size (config.vocab_size).
            loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
            loss_d_fct = CrossEntropyLoss()
            loss_labels = loss_labels_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss_d = loss_d_fct(logits_d.view(-1, self.num_detect_classes), d_tags.view(-1))
            loss = loss_labels + loss_d

        if not return_dict:
            # Tuple output: (loss?, logits, detect_logits, hidden_states?, attentions?).
            output = (logits, logits_d) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Seq2LabelsOutput(
            loss=loss,
            logits=logits,
            detect_logits=logits_d,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # Placeholder per-sentence score: a constant probability of 1 for every sentence.
            max_error_probability=torch.ones(logits.size(0), device=logits.device),
        )
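

# Minimal usage sketch (illustrative, not part of the original module). It assumes a
# config exposing the attributes read in __init__ above; a plain PretrainedConfig with
# those attributes passed as keyword arguments is enough for a smoke test. The checkpoint
# name, the label/detection sizes and initializer_range below are arbitrary example
# values, and from_pretrained() will download the encoder weights.
if __name__ == "__main__":
    from transformers import AutoTokenizer, PretrainedConfig

    config = PretrainedConfig(
        num_labels=5000,
        vocab_size=5000,            # size of the correction-label vocabulary (classifier head)
        num_detect_classes=4,
        label_smoothing=0.0,
        load_pretrained=True,
        pretrained_name_or_path="bert-base-cased",
        special_tokens_fix=False,
        predictor_dropout=0.0,
        initializer_range=0.02,     # needed by BertPreTrainedModel weight initialization
    )
    model = Seq2LabelsModel(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    batch = tokenizer(["She go to school yesterday ."], return_tensors="pt")
    with torch.no_grad():
        out = model(**batch)

    # (batch, seq_len, vocab_size) and (batch, seq_len, num_detect_classes)
    print(out.logits.shape, out.detect_logits.shape)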