bert_based_ner / bert_ner_model_loader.py
pragnakalp's picture
Update bert_ner_model_loader.py
84a67cd
"""BERT NER Inference."""
from __future__ import absolute_import, division, print_function
import json
import os
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import nltk
nltk.download('punkt')
from nltk import word_tokenize
# from transformers import (BertConfig, BertForTokenClassification,
# BertTokenizer)
from pytorch_transformers import (BertForTokenClassification, BertTokenizer)
class BertNer(BertForTokenClassification):
    """Token classifier that scores only the positions flagged in ``valid_ids``."""

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None):
        """Return classifier logits for the valid positions of each sequence.

        Args:
            input_ids: Token ids, forwarded unchanged to ``self.bert``.
            token_type_ids: Optional segment ids, forwarded to ``self.bert``.
            attention_mask: Optional attention mask, forwarded to ``self.bert``.
            valid_ids: Optional 2-D tensor indexed as ``[batch, position]``.
                Positions whose value is 1 are kept; their hidden states are
                packed to the front of each row and the remainder of the row
                is left as zeros.  When ``None``, every position is kept.

        Returns:
            The output of ``self.classifier`` applied (after dropout) to the
            packed hidden states; one row of logits per sequence position.
        """
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]
        batch_size, max_len, _feat_dim = sequence_output.shape

        if valid_ids is None:
            # No mask supplied: nothing to drop, so nothing to re-pack.
            valid_output = sequence_output
        else:
            # Allocate on the same device/dtype as the encoder output so the
            # model also works when it is not running on the CPU in float32.
            valid_output = torch.zeros_like(sequence_output)
            keep = valid_ids[:, :max_len].to(sequence_output.device) == 1
            for row in range(batch_size):
                # Boolean-mask selection preserves order, so the k-th kept
                # position ends up at index k of the packed row.
                kept_states = sequence_output[row][keep[row]]
                valid_output[row, :kept_states.shape[0]] = kept_states

        sequence_output = self.dropout(valid_output)
        logits = self.classifier(sequence_output)
        return logits
class Ner:
    """Load a saved ``BertNer`` model and extract named entities from raw text."""

    # Display names for the entity-start ("B-") tags returned by ``predict``.
    # NOTE(review): 'Langauge' is misspelled, but the value is returned to
    # callers at runtime, so it is kept unchanged here for compatibility.
    ENTITY_NAMES = {
        'B-PER': 'Person',
        'B-FAC': 'Facility',
        'B-LOC': 'Location',
        'B-ORG': 'Organization',
        'B-ART': 'Work Of Art',
        'B-EVENT': 'Event',
        'B-DATE': 'Date-Time Entity',
        'B-TIME': 'Date-Time Entity',
        'B-LAW': 'Law Terms',
        'B-PRODUCT': 'Product',
        'B-PERCENT': 'Percentage',
        'B-MONEY': 'Currency',
        'B-LANGUAGE': 'Langauge',
        'B-NORP': 'Nationality / Religion / Political group',
        'B-QUANTITY': 'Quantity',
        'B-ORDINAL': 'Ordinal Number',
        'B-CARDINAL': 'Cardinal Number',
    }

    def __init__(self, model_dir: str):
        """Load model, tokenizer and config from ``model_dir`` and prepare for inference."""
        self.model, self.tokenizer, self.model_config = self.load_model(model_dir)
        # JSON object keys are strings, while predicted label ids are ints.
        self.label_map = {int(k): v for k, v in self.model_config["label_map"].items()}
        self.max_seq_length = self.model_config["max_seq_length"]
        # Inference is pinned to the CPU.
        self.device = "cpu"
        self.model = self.model.to(self.device)
        self.model.eval()

    def load_model(self, model_dir: str, model_config: str = "model_config.json"):
        """Read the JSON config and load the model and tokenizer from ``model_dir``.

        Args:
            model_dir: Directory holding the pretrained weights, vocabulary
                and the JSON config file.
            model_config: File name of the JSON config inside ``model_dir``.
                It must provide a ``"do_lower"`` entry.

        Returns:
            A ``(model, tokenizer, config_dict)`` tuple.
        """
        config_path = os.path.join(model_dir, model_config)
        # Context manager so the file handle is closed even if parsing fails.
        with open(config_path) as config_file:
            config = json.load(config_file)
        model = BertNer.from_pretrained(model_dir)
        tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=config["do_lower"])
        return model, tokenizer, config

    def tokenize(self, text: str):
        """Split ``text`` into words, then each word into tokenizer pieces.

        Returns:
            ``(tokens, valid_positions)`` where ``valid_positions`` has one
            entry per token: 1 for the first piece of a word, 0 otherwise.
        """
        tokens = []
        valid_positions = []
        for word in word_tokenize(text):
            pieces = self.tokenizer.tokenize(word)
            tokens.extend(pieces)
            # A word that yields no pieces contributes no flags either.
            valid_positions.extend(1 if k == 0 else 0 for k in range(len(pieces)))
        return tokens, valid_positions

    def preprocess(self, text: str):
        """Build the padded model inputs for ``text``.

        "[CLS]" and "[SEP]" are added around the tokens and both are flagged
        as valid positions.  All lists are right-padded with zeros up to
        ``self.max_seq_length``; longer inputs are returned unpadded and are
        not truncated.

        Returns:
            ``(input_ids, input_mask, segment_ids, valid_positions)`` as
            plain Python lists of equal length.
        """
        tokens, valid_positions = self.tokenize(text)
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        valid_positions = [1] + valid_positions + [1]

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = [0] * len(tokens)

        padding = [0] * max(0, self.max_seq_length - len(input_ids))
        input_ids = input_ids + padding
        input_mask = input_mask + padding
        segment_ids = segment_ids + padding
        valid_positions = valid_positions + padding
        return input_ids, input_mask, segment_ids, valid_positions

    def predict_entity(self, B_lab, I_lab, words, labels, entity_list):
        """Group tagged words into entities and attach display names.

        Args:
            B_lab: Per-word tag if the word starts an entity, else ``'O'``.
            I_lab: Per-word tag if the word continues an entity, else ``'O'``.
            words: The words of the sentence.
            labels: Per-word ``(tag, confidence)`` pairs.
            entity_list: Mapping from entity-start tag to display name.

        Returns:
            A list of ``[entity_text, display_name]`` pairs, in order of
            appearance.  Entities whose start tag is missing from
            ``entity_list`` are omitted.
        """
        groups = []
        # Words tagged as a continuation before any entity has started are
        # collected in this first group, which is discarded below.
        current = []
        for word, (label, _confidence), b_tag, i_tag in zip(words, labels, B_lab, I_lab):
            if label == 'O' or label not in (b_tag, i_tag):
                continue
            if label == b_tag:
                # A start tag closes the running group and opens a new one
                # whose first element is the tag itself.
                groups.append(current)
                current = [label]
            current.append(word)
        groups.append(current)

        return [
            [' '.join(group[1:]), entity_list[group[0]]]
            for group in groups[1:]
            if group[0] in entity_list
        ]

    def predict(self, text: str):
        """Run the model on ``text`` and return the recognised entities.

        Returns:
            A list of ``[entity_text, display_name]`` pairs as produced by
            ``predict_entity`` using ``ENTITY_NAMES``.
        """
        input_ids, input_mask, segment_ids, valid_positions = self.preprocess(text)
        input_ids = torch.tensor([input_ids], dtype=torch.long, device=self.device)
        input_mask = torch.tensor([input_mask], dtype=torch.long, device=self.device)
        segment_ids = torch.tensor([segment_ids], dtype=torch.long, device=self.device)
        valid_ids = torch.tensor([valid_positions], dtype=torch.long, device=self.device)

        with torch.no_grad():
            logits = self.model(input_ids, segment_ids, input_mask, valid_ids)
            probabilities = F.softmax(logits, dim=2)[0]
            confidences, label_ids = probabilities.max(dim=-1)
        confidences = confidences.tolist()
        label_ids = label_ids.tolist()

        # The model packs the outputs of valid positions to the front, so the
        # k-th valid position is found at index k.  Index 0 is "[CLS]" and the
        # last valid index is "[SEP]"; both are skipped.
        valid_count = sum(valid_positions)
        labels = [
            (self.label_map[label_ids[k]], confidences[k])
            for k in range(1, valid_count - 1)
        ]

        words = word_tokenize(text)
        assert len(labels) == len(words), (
            "got %d labels for %d words" % (len(labels), len(words))
        )

        B_labels = [label if label[:1] == 'B' else 'O' for label, _ in labels]
        I_labels = [label if label[:1] == 'I' else 'O' for label, _ in labels]
        return self.predict_entity(B_labels, I_labels, words, labels, self.ENTITY_NAMES)