from torch import nn
from transformers import PreTrainedModel, BertModel
from transformers.configuration_utils import PretrainedConfig

from ..crf import CRF
from .configuration_bert import BertCrfConfig
class BertCrfModel(PreTrainedModel):
    """BERT + BiLSTM + CRF model for sequence labeling.

    Args:
        config (BertCrfConfig): model configuration.
        num_tags (int, optional): overrides ``config.num_tags`` when provided.

    Returns:
        loss (torch.Tensor or None): batch loss when ``labels`` is given.
        (best_path, labels): CRF best paths (with the [CLS]/[SEP] positions
            stripped) together with the gold labels.
    """

    config_class = BertCrfConfig
    def __init__(self, config, num_tags=None):
        super().__init__(config)
        if num_tags is not None:
            config.num_tags = num_tags
        # Token-level encoder; the pooling head is not needed for sequence labeling.
        self.bert = BertModel(config=config, add_pooling_layer=False)
        # Single-layer bidirectional LSTM on top of the BERT hidden states.
        self.lstm = nn.LSTM(config.hidden_size, config.lstm_hidden_state, 1,
                            batch_first=True, bidirectional=True)
        self.crf = CRF(config.num_tags)
        # Projects the concatenated forward/backward LSTM states to emission scores.
        self.fc = nn.Linear(config.lstm_hidden_state * 2, config.num_tags)

    def forward(self, input_ids, attention_mask, token_type_ids, input_mask, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_states = outputs[0]
        lstm_hidden_states = self.lstm(hidden_states)[0]
        emission_scores = self.fc(lstm_hidden_states)
        loss = None
        if labels is not None:
            loss = self.crf.loss(emission_scores, labels, input_mask == 0)
        _, best_path = self.crf(emission_scores, input_mask == 0)
        # Drop the [CLS]/[SEP] positions from each decoded path before returning.
        return loss, (
            [path[1:-1] for path in best_path],
            labels.cpu() if labels is not None else None,
        )
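

if __name__ == "__main__":
    # Illustrative smoke test only, not part of the original model code. Run it as a
    # module (e.g. `python -m <package>.<module>`) so the relative imports resolve.
    # Assumptions: BertCrfConfig subclasses BertConfig (so it carries BERT-style
    # defaults such as vocab_size) and exposes `num_tags` and `lstm_hidden_state`;
    # adjust the names and defaults if the real config differs.
    import torch

    config = BertCrfConfig()           # assumed to construct with BERT-style defaults
    config.num_tags = 9                # e.g. a BIO tag set for NER
    config.lstm_hidden_state = 256     # BiLSTM hidden size per direction

    model = BertCrfModel(config)       # randomly initialised weights, no download
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    # forward() passes `input_mask == 0` to the CRF, so zeros here mark every
    # position as kept; check the CRF's mask convention before adding real padding.
    input_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
    labels = torch.randint(0, config.num_tags, (batch_size, seq_len))

    with torch.no_grad():
        loss, (best_path, gold_labels) = model(
            input_ids, attention_mask, token_type_ids, input_mask, labels
        )
    print("loss:", loss)
    print("decoded tags (first sequence):", best_path[0])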