# Source: vibert-capu / modeling_seq2labels.py (Hugging Face Hub file page)
# Last commit: "Fix device bug" (5859a7b) by dragonSwing; file size 5.14 kB
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, BertPreTrainedModel
from transformers.modeling_outputs import ModelOutput
def get_range_vector(size: int, device: Union[int, str, torch.device]) -> torch.Tensor:
    """
    Return a 1-D ``torch.long`` tensor holding ``0, 1, ..., size - 1``.

    The tensor is created directly on ``device`` by ``torch.arange``, so the
    caller does not have to move it there afterwards.

    Args:
        size: Number of elements in the range.
        device: Target device, forwarded unchanged to ``torch.arange``
            (a ``torch.device``, a device string such as ``"cpu"``, or a
            device ordinal).

    Returns:
        Tensor of shape ``(size,)`` with dtype ``torch.long``.
    """
    return torch.arange(0, size, dtype=torch.long, device=device)
@dataclass
class Seq2LabelsOutput(ModelOutput):
    """
    Output container for `Seq2LabelsModel.forward`.

    `ModelOutput` subclasses have to be dataclasses: the base class relies on
    the dataclass fields to build its tuple/dict views, and recent
    `transformers` releases refuse to instantiate a subclass that is not one.

    Fields:
        loss: Sum of the label loss and the detection loss; `None` when no
            targets were passed to `forward`.
        logits: Output of the label classifier head.
        detect_logits: Output of the detection head.
        hidden_states: `hidden_states` of the underlying encoder output.
        attentions: `attentions` of the underlying encoder output.
        max_error_probability: One value per batch element; `forward`
            currently fills it with ones.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    detect_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    max_error_probability: Optional[torch.FloatTensor] = None
class Seq2LabelsModel(BertPreTrainedModel):
    """
    Transformer encoder with two per-token linear heads.

    * ``classifier`` maps each token representation to ``config.vocab_size``
      label scores.
    * ``detector`` maps each token representation to
      ``config.num_detect_classes`` detection scores.

    The encoder is built through ``AutoModel`` from
    ``config.pretrained_name_or_path``; its weights are downloaded only when
    ``config.load_pretrained`` is true, otherwise only its config is fetched
    and the encoder is instantiated from that config.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """
        Build the encoder and both heads from ``config``.

        Reads ``num_labels``, ``num_detect_classes``, ``label_smoothing``,
        ``load_pretrained``, ``pretrained_name_or_path``,
        ``special_tokens_fix``, ``predictor_dropout`` and ``vocab_size`` from
        ``config``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_detect_classes = config.num_detect_classes
        self.label_smoothing = config.label_smoothing

        if config.load_pretrained:
            self.bert = AutoModel.from_pretrained(config.pretrained_name_or_path)
            bert_config = self.bert.config
        else:
            bert_config = AutoConfig.from_pretrained(config.pretrained_name_or_path)
            self.bert = AutoModel.from_config(bert_config)

        if config.special_tokens_fix:
            try:
                vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # Encoders without `embeddings.word_embeddings` expose
                # `word_embedding` instead; reserve more space for them.
                vocab_size = self.bert.word_embedding.num_embeddings + 5
            # Grow the embedding matrix by one extra row.
            self.bert.resize_token_embeddings(vocab_size + 1)

        predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
        self.dropout = nn.Dropout(predictor_dropout)
        self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
        self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_offsets: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        d_tags: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
        r"""
        Run the encoder and both heads; optionally compute the joint loss.

        input_offsets (`torch.LongTensor`, *optional*):
            Per-example indices into the encoder's sequence dimension. When given, the encoder output is
            gathered at these positions before the heads are applied (presumably the first sub-token of each
            original word -- confirm against the tokenization code).
        labels (`torch.LongTensor`, *optional*):
            Targets for the label head. Class indices should be in `[0, ..., config.vocab_size - 1]`, i.e. the
            output width of the classifier.
        d_tags (`torch.LongTensor`, *optional*):
            Targets for the detection head. Class indices should be in
            `[0, ..., config.num_detect_classes - 1]`.

        The loss is computed only when both `labels` and `d_tags` are provided; it is the sum of the
        (optionally label-smoothed) label cross-entropy and the detection cross-entropy.

        Returns a `Seq2LabelsOutput` when `return_dict` resolves to true, otherwise a tuple
        `(loss?, logits, detect_logits, *encoder_outputs[2:])`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if input_offsets is not None:
            # A (batch_size, 1) column of batch indices broadcasts against
            # `input_offsets`, so each example gathers from its own row.
            range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
            sequence_output = sequence_output[range_vector, input_offsets]

        # Dropout is applied only on the label head's input; the detection
        # head sees the raw encoder output.
        logits = self.classifier(self.dropout(sequence_output))
        logits_d = self.detector(sequence_output)

        loss = None
        if labels is not None and d_tags is not None:
            loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
            loss_d_fct = CrossEntropyLoss()
            # Flatten with each head's actual output width. The classifier is
            # sized by `config.vocab_size`, which need not equal
            # `self.num_labels`, so the class dimension is taken from the
            # logits themselves.
            loss_labels = loss_labels_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            loss_d = loss_d_fct(logits_d.reshape(-1, logits_d.size(-1)), d_tags.reshape(-1))
            loss = loss_labels + loss_d

        if not return_dict:
            # Skips the first two encoder outputs; assumes index 1 is a pooled
            # output -- TODO confirm for encoders without a pooler.
            output = (logits, logits_d) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Seq2LabelsOutput(
            loss=loss,
            logits=logits,
            detect_logits=logits_d,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            max_error_probability=torch.ones(logits.size(0), device=logits.device),
        )