from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers import BertModel, BertConfig
from transformers import AutoModelForTokenClassification, AutoConfig
from torchcrf import CRF


class BERT_CRF_Config(PretrainedConfig):
    model_type = "BERT_CRF"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_name = "BERT_CRF"


class BERT_CRF(PreTrainedModel):
    config_class = BERT_CRF_Config

    def __init__(self, config):
        super().__init__(config)
        # Load the BERT backbone named in the config, exposing attentions and hidden states.
        bert_config = BertConfig.from_pretrained(config.bert_name)
        bert_config.output_attentions = True
        bert_config.output_hidden_states = True
        self.bert = BertModel.from_pretrained(config.bert_name, config=bert_config)
        self.dropout = nn.Dropout(p=0.5)
        # Per-token emission scores over the label set, fed to the CRF layer.
        self.linear = nn.Linear(self.bert.config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None, labels_mask=None):
        last_hidden_layer = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )['last_hidden_state']
        last_hidden_layer = self.dropout(last_hidden_layer)
        logits = self.linear(last_hidden_layer)

        batch_size = logits.shape[0]
        output_tags = []

        if labels is not None:
            # Training: accumulate the negative CRF log-likelihood per sequence.
            loss = 0
            for seq_logits, seq_labels, seq_mask in zip(logits, labels, labels_mask):
                # Index logits and labels with the prediction mask so only the
                # first subtoken of each word is passed to the CRF.
                seq_logits = seq_logits[seq_mask].unsqueeze(0)
                seq_labels = seq_labels[seq_mask].unsqueeze(0)
                if seq_logits.numel() != 0:
                    loss -= self.crf(seq_logits, seq_labels, reduction='token_mean')
            return loss / batch_size
        else:
            # Inference: Viterbi-decode the best tag sequence for each example.
            for seq_logits, seq_mask in zip(logits, labels_mask):
                seq_logits = seq_logits[seq_mask].unsqueeze(0)
                if seq_logits.numel() != 0:
                    tags = self.crf.decode(seq_logits)
                else:
                    tags = [[]]
                # Unpack the single-sequence "batch" result.
                output_tags.append(tags[0])
            return output_tags


class ModelRegisterStep:
    def __call__(self, args):
        # Make the custom architecture usable through the Auto* factories.
        AutoConfig.register("BERT_CRF", BERT_CRF_Config)
        AutoModelForTokenClassification.register(BERT_CRF_Config, BERT_CRF)
        return {
            **args,
        }
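

# Usage sketch (illustrative, not part of the original upload): it registers the
# custom classes, builds an untrained model, and runs CRF decoding on a dummy
# batch. The checkpoint name, label count, and token ids below are placeholder
# assumptions, not values from this repository.
if __name__ == "__main__":
    import torch

    ModelRegisterStep()({})

    config = BERT_CRF_Config(bert_name="bert-base-cased", num_labels=9)
    model = BERT_CRF(config)
    model.eval()

    input_ids = torch.tensor([[101, 7592, 2088, 102]])
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.ones_like(input_ids)
    # Decode only the two "word" positions, skipping [CLS] and [SEP].
    labels_mask = torch.tensor([[False, True, True, False]])

    with torch.no_grad():
        tags = model(input_ids, token_type_ids, attention_mask,
                     labels=None, labels_mask=labels_mask)
    print(tags)  # one predicted tag id per unmasked token (arbitrary until trained)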