import torch
from torch import nn
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel


class ClassifierNER(BertPreTrainedModel):
    """Joint sentence classification and NER on a shared BERT encoder.

    Expects extra config attributes: `tokenizer` (a tokenizer name or path)
    and `clf_labels`/`ner_labels` (dicts mapping string class indices to
    label names).
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=True)
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.loss_fct = nn.CrossEntropyLoss()

        # Sentence-level head over the pooled [CLS] output.
        self.clf_labels = config.clf_labels
        self.clf_classes = len(self.clf_labels)
        self.clf_linear = nn.Linear(config.hidden_size, self.clf_classes)

        # Token-level head: BiLSTM (two directions of hidden_size // 2
        # concatenate back to hidden_size), then a projection to tag space.
        # Note: nn.LSTM only applies dropout between stacked layers, so the
        # dropout argument is a no-op for this single layer.
        self.ner_labels = config.ner_labels
        self.ner_classes = len(self.ner_labels)
        self.ner_linear = nn.Linear(config.hidden_size, self.ner_classes)
        self.ner_lstm = nn.LSTM(config.hidden_size, config.hidden_size // 2,
                                dropout=config.hidden_dropout_prob,
                                batch_first=True, bidirectional=True)

        # Standard transformers weight initialization for the new heads.
        self.post_init()

    def forward(self, input_ids, token_type_ids, attention_mask,
                clf_labels=None, ner_labels=None, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, **kwargs)

        # Sentence-level logits from the pooled output.
        clf_output = self.dropout(outputs[1])
        clf_logits = self.clf_linear(clf_output)
        clf_loss = 0
        if clf_labels is not None:
            if not torch.is_tensor(clf_labels):
                clf_labels = torch.tensor(clf_labels, dtype=torch.long)
            clf_labels = clf_labels.to(clf_logits.device)
            clf_loss = self.loss_fct(clf_logits.view(-1, self.clf_classes),
                                     clf_labels.view(-1))

        # Token-level logits from the per-token sequence output.
        ner_output = self.dropout(outputs[0])
        lstm_output, _ = self.ner_lstm(ner_output)
        ner_logits = self.ner_linear(lstm_output)
        ner_loss = 0
        if ner_labels is not None:
            ner_loss = self.loss_fct(ner_logits.view(-1, self.ner_classes),
                                     ner_labels.view(-1))

        if clf_labels is not None or ner_labels is not None:
            # Joint training objective: unweighted sum of the two losses.
            return clf_loss + ner_loss, clf_logits, ner_logits
        return clf_logits, ner_logits
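
    # Sketch of a training step under assumed batch fields (clf_labels of
    # shape [batch]; ner_labels of shape [batch, seq_len] with padding and
    # special-token positions set to -100, which CrossEntropyLoss ignores
    # by default):
    #
    #   loss, _, _ = model(batch["input_ids"], batch["token_type_ids"],
    #                      batch["attention_mask"],
    #                      clf_labels=batch["clf_labels"],
    #                      ner_labels=batch["ner_labels"])
    #   loss.backward()
    #   optimizer.step()
    #   optimizer.zero_grad()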

    def predict(self, text):
        """Run inference on a single string; return the sentence label and entities."""
        self.eval()  # disable dropout for inference
        with torch.no_grad():
            # Offset mapping requires a fast tokenizer (AutoTokenizer's default).
            tokenized = self.tokenizer(text, truncation=True, max_length=512,
                                       return_tensors="pt",
                                       return_offsets_mapping=True)
            device = next(self.parameters()).device
            clf_prediction, ner_prediction = self(
                tokenized["input_ids"].to(device),
                tokenized["token_type_ids"].to(device),
                tokenized["attention_mask"].to(device),
            )
            clf_prediction = self.clf_labels[str(torch.argmax(clf_prediction, dim=-1).item())]
            ner_prediction = self.align_predictions(text, ner_prediction,
                                                    tokenized["offset_mapping"])
        return {"classification": clf_prediction, "entities": ner_prediction}

    def align_predictions(self, text, predictions, offsets):
        """Group BIO tag predictions into character-level entity spans."""
        results = []
        predictions = torch.argmax(predictions, dim=-1)[0].tolist()
        offsets = offsets[0].tolist()
        idx = 0
        while idx < len(predictions):
            label = self.ner_labels[str(predictions[idx])]
            start, end = offsets[idx]
            # Skip non-entity tokens and special tokens ([CLS]/[SEP] have
            # empty (0, 0) offsets), advancing exactly one position so the
            # token right after an entity is never skipped.
            if label == "O" or start == end:
                idx += 1
                continue
            # Strip the "B-"/"I-" prefix, then absorb the consecutive "I-"
            # tokens that continue the same entity, extending `end`.
            label = label[2:]
            idx += 1
            while (idx < len(predictions)
                   and self.ner_labels[str(predictions[idx])] == f"I-{label}"):
                _, end = offsets[idx]
                idx += 1
            results.append({
                "label": label,
                "entity": text[start:end],
                "start": start,
                "end": end,
            })
        return results
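

# Minimal usage sketch. The checkpoint name and label maps are illustrative
# assumptions, not shipped with this module; config.tokenizer, config.clf_labels,
# and config.ner_labels must match whatever the model was trained with.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig.from_pretrained("bert-base-cased")  # assumed checkpoint
    config.tokenizer = "bert-base-cased"
    config.clf_labels = {"0": "statement", "1": "question"}  # hypothetical
    config.ner_labels = {"0": "O", "1": "B-PER", "2": "I-PER",
                         "3": "B-LOC", "4": "I-LOC"}  # hypothetical
    # Loads the BERT encoder weights; the task heads start randomly initialized.
    model = ClassifierNER.from_pretrained("bert-base-cased", config=config)
    print(model.predict("Ada Lovelace was born in London."))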