import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_outputs import TokenClassifierOutput


class BertCrfForTokenClassification(BertPreTrainedModel):
    """BERT token classifier with a CRF decoding layer on top of the emissions."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        # Linear layer produces per-token emission scores consumed by the CRF.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        # Placeholder "logits" that will carry the CRF decode as one-hot
        # vectors, so a downstream argmax recovers the Viterbi path.
        dummy_logits = torch.zeros_like(logits)
        # Number of real tokens per sequence, excluding [CLS] and [SEP].
        valid_lens = attention_mask.sum(dim=1) - 2
        # Drop the [CLS] position before running the CRF.
        logits = logits[:, 1:]
        labels_mask = torch.arange(logits.size(1)).to(
            valid_lens.device
        ) < valid_lens.unsqueeze(1)
        seq_label_ids = self.crf.decode(logits, mask=labels_mask)
        loss = None
        if labels is not None:
            labels = labels[:, 1:]
            is_pad = labels == -100
            # Replace the -100 padding sentinel with a valid label id; padded
            # positions are excluded from the CRF score by labels_mask anyway.
            # Out-of-place masked_fill avoids mutating the caller's tensor.
            labels = labels.masked_fill(is_pad, 0)
            assert torch.eq(~is_pad, labels_mask).all().item(), "mask assertion failed"
            loss = -self.crf(logits, labels, mask=labels_mask, reduction="token_mean")
        # Pad the decoded label paths back to a rectangular tensor and one-hot
        # encode them into the placeholder logits.
        padded_list = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(lst, device=logits.device) for lst in seq_label_ids],
            batch_first=True,
            padding_value=0,
        )
        padded_list = torch.nn.functional.pad(
            padded_list, (0, logits.size(1) - padded_list.shape[1])
        )
        padded_list = torch.nn.functional.one_hot(
            padded_list, num_classes=logits.size(2)
        )
        assert dummy_logits.size(1) == padded_list.size(1) + 1, "size assertion failed"
        # Leave the [CLS] slot as all-zeros; fill the rest with the decode.
        dummy_logits[:, 1:] = padded_list
        if not return_dict:
            output = (dummy_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=dummy_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
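

# --- Usage sketch (illustrative, not part of the uploaded module) ---
# A minimal example of how this head could be consumed, assuming a
# hypothetical base checkpoint ("bert-base-chinese") and a 3-label scheme;
# swap in the real checkpoint and id2label mapping for actual use.
if __name__ == "__main__":
    from transformers import BertConfig, BertTokenizerFast

    config = BertConfig.from_pretrained("bert-base-chinese", num_labels=3)
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
    model = BertCrfForTokenClassification(config)
    model.eval()

    enc = tokenizer("你好世界", return_tensors="pt")
    with torch.no_grad():
        out = model(**enc)
    # forward() one-hot encodes the CRF decode into `logits`, so argmax
    # recovers the Viterbi path directly.
    print(out.logits.argmax(dim=-1))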