# AcroBERT / acrobert.py
# Last commit by Lihuchen: "Update acrobert.py" (abe8eed); file size 5.02 kB.
import numpy as np
from math import exp
import torch
from torch import nn
from transformers import BertTokenizer, BertForNextSentencePrediction
import utils
from maddog import Extractor
import spacy
import constant
# Everything below is loaded once, at import time.
# spaCy English pipeline; in this file it is only used to split a sentence into tokens
# before rule-based acronym extraction.
nlp = spacy.load("en_core_web_sm")
# MadDog rule-based extractor, applied to the spaCy tokens with constant.RULES.
ruleExtractor = Extractor()
# Acronym knowledge base read from acronym_kb.json; candidate long forms for an
# acronym are drawn from it via utils.get_candidate.
kb = utils.load_acronym_kb('acronym_kb.json')
# Path of the fine-tuned AcroBERT state dict loaded into the module-level `model`.
model_path='acrobert.pt'
class AcronymBERT(nn.Module):
    """BERT next-sentence-prediction (NSP) model used to score acronym expansions.

    Bundles a Hugging Face ``BertForNextSentencePrediction`` model with its matching
    ``BertTokenizer``. Other code in this module uses the two attributes ``model`` and
    ``tokenizer`` directly. ``forward`` is a convenience scorer.
    """

    def __init__(self, model_name="bert-base-uncased", device='cpu'):
        super().__init__()
        # Device that ``forward`` moves its input tensors to. This does NOT move the
        # module's weights; callers are expected to call ``.to(device)`` themselves.
        self.device = device
        self.model = BertForNextSentencePrediction.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, sentence):
        """Score ``sentence`` with the NSP head.

        Args:
            sentence: text (or batch of texts) accepted by the tokenizer.

        Returns:
            1-D tensor with, for each row, the softmax probability of NSP label 0
            (in Hugging Face's convention, "sequence B continues sequence A").
        """
        encoding = self.tokenizer(sentence, padding=True, return_tensors='pt', truncation=True)
        # Pass every tokenizer output (input_ids, token_type_ids, attention_mask), not
        # just input_ids. With padding=True and no attention mask, pad tokens would be
        # attended to, so a row's score would depend on the other rows in the batch.
        inputs = {name: tensor.to(self.device) for name, tensor in encoding.items()}
        logits = self.model(**inputs).logits
        return nn.Softmax(dim=1)(logits)[:, 0]
# Module-level default AcroBERT instance; acronym_linker uses it as its default `model`.
# It is built on the CPU at import time. AcronymBERT.__init__ calls from_pretrained,
# which may need to fetch bert-base-uncased if it is not cached locally (TODO confirm
# the deployment environment has it).
model = AcronymBERT(device='cpu')
# Load the fine-tuned weights from model_path, mapping all tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load a trusted checkpoint.
# On torch versions that support it, weights_only=True would be safer for a state dict.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
def softmax(elements):
    """Return the softmax probability of the FIRST element of ``elements``.

    Args:
        elements: non-empty sequence of real-valued logits (e.g. one row of a
            numpy logits array).

    Returns:
        ``exp(elements[0]) / sum(exp(e) for e in elements)`` as a float.

    Raises:
        ValueError: if ``elements`` is empty.
    """
    # Subtract the largest logit before exponentiating. This leaves the result
    # mathematically unchanged but avoids math.exp overflowing (OverflowError) on large
    # logits; after the shift, every exponent is <= 0.
    peak = max(elements)
    total = sum(exp(e - peak) for e in elements)
    return exp(elements[0] - peak) / total
def predict(topk, model, short_form, context, batch_size, acronym_kb, device):
    """Rank knowledge-base expansions of ``short_form`` for the given ``context``.

    Candidate long forms are requested from ``utils.get_candidate`` with
    ``can_num=20``. Each candidate is lower-cased and scored against ``context`` by
    ``cal_score`` using ``model.model`` and ``model.tokenizer``.

    Returns:
        A pair ``(names, confidences)``:
        - ``names``: the best ``topk`` candidates in their original (not lower-cased)
          spelling, highest score first;
        - ``confidences``: the matching scores rounded to 3 decimals.
        Both lists are shorter than ``topk`` when fewer candidates exist.
    """
    candidates = utils.get_candidate(acronym_kb, short_form, can_num=20)
    lowered = list(map(str.lower, candidates))
    scores = cal_score(model.model, model.tokenizer, lowered, context, batch_size, device)
    keep = min(len(scores), topk)
    # argsort is ascending; reverse it for best-first order, then keep the top `keep`.
    ranking = np.array(scores).argsort()[::-1][:keep]
    names = [candidates[i] for i in ranking]
    confidences = [round(scores[i], 3) for i in ranking]
    return names, confidences
def cal_score(model, tokenizer, long_forms, contexts, batch_size, device):
    """Score every candidate long form against the context with an NSP model.

    Each candidate is encoded as the first segment and ``contexts`` as the second
    segment. Candidates are processed in chunks of ``batch_size``.

    Args:
        model: NSP model called with the tokenizer's encoding; must expose ``.logits``.
        tokenizer: tokenizer matching ``model``.
        long_forms: list of candidate expansions.
        contexts: context text. Despite the plural name, the same value is paired with
            every candidate; it is repeated, not indexed.
        batch_size: number of candidates per forward pass (must be positive).
        device: device the encoded inputs are moved to.

    Returns:
        List of floats, one per entry of ``long_forms`` and in the same order: the
        softmax probability of the first logit of each row.
    """
    ps = list()
    # Inference only: disable autograd so no computation graph is built for the batches.
    with torch.no_grad():
        for start in range(0, len(long_forms), batch_size):
            batch_lf = long_forms[start:start + batch_size]
            batch_ctx = [contexts] * len(batch_lf)
            encoding = tokenizer(
                batch_lf,
                batch_ctx,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=400,
            ).to(device)
            logits = model(**encoding).logits.detach().cpu().numpy()
            ps.extend(softmax(row) for row in logits)
    return ps
def dog_extract(sentence):
    """Run the rule-based (MadDog) extractor over ``sentence``.

    The sentence is tokenized with the spaCy pipeline ``nlp``, and tokens that are
    empty or whitespace-only are dropped. The rules in ``constant.RULES`` are then
    applied to the remaining tokens.

    Returns:
        Whatever ``ruleExtractor.extract`` returns. Elsewhere in this file the result
        is treated as a mapping from acronym to a sequence whose first item is the long
        form, or '' when none was found (presumably; confirm against maddog).
    """
    words = []
    for token in nlp(sentence):
        if token.text.strip():
            words.append(token.text)
    return ruleExtractor.extract(words, constant.RULES)
def acrobert(sentence, model, device, *, topk=5, batch_size=10):
    """Link every acronym found in ``sentence`` using AcroBERT.

    If the rule-based extractor returns a non-empty long form for an acronym, that
    string is used as-is. Otherwise, candidates from the knowledge base ``kb`` are
    ranked by ``predict``.

    Args:
        sentence: input text; also used as the context for model scoring.
        model: ``AcronymBERT`` instance. It is moved to ``device`` in place.
        device: torch device to run on.
        topk: maximum number of ranked expansions kept per acronym (keyword-only).
        batch_size: number of candidates scored per forward pass (keyword-only).

    Returns:
        Dict mapping each acronym to either the rule-based long form (``str``) or a
        list of up to ``topk`` ``(long_form, score)`` tuples, best first.
    """
    model.to(device)
    results = dict()
    for acronym, extracted in dog_extract(sentence).items():
        long_form = extracted[0]
        if long_form != '':
            results[acronym] = long_form
        else:
            names, scores = predict(topk, model, acronym, sentence, batch_size=batch_size,
                                    acronym_kb=kb, device=device)
            results[acronym] = list(zip(names, scores))
    return results
def popularity(sentence):
    """Popularity baseline: link each acronym without using the model.

    If the rule-based extractor returns a non-empty long form, that string is used.
    Otherwise the first candidate from ``utils.get_candidate(kb, acronym, can_num=1)``
    is used (presumably the most frequent expansion in the KB; confirm in utils).

    Returns:
        List of ``(acronym, long_form)`` tuples, in the order the extractor yields
        acronyms.

    NOTE(review): if the KB has no candidate for an acronym, indexing ``[0]`` fails.
    """
    results = list()
    for acronym, extracted in dog_extract(sentence).items():
        long_form = extracted[0]
        if long_form == '':
            long_form = utils.get_candidate(kb, acronym, can_num=1)[0]
        results.append((acronym, long_form))
    return results
def acronym_linker(sentence, mode='acrobert', model=model, device='cpu'):
    """Link the acronyms in ``sentence`` using the selected strategy.

    Args:
        sentence: input text.
        mode: ``'acrobert'`` for model-based ranking (see ``acrobert``) or ``'pop'``
            for the popularity baseline (see ``popularity``).
        model: ``AcronymBERT`` used in ``'acrobert'`` mode. The default is the
            module-level model, bound when this function is defined. It is ignored in
            ``'pop'`` mode.
        device: torch device used in ``'acrobert'`` mode.

    Returns:
        The result of ``acrobert`` (a dict) or ``popularity`` (a list of tuples).

    Raises:
        ValueError: if ``mode`` is not a supported name.
    """
    if mode == 'acrobert':
        return acrobert(sentence, model, device)
    if mode == 'pop':
        return popularity(sentence)
    # ValueError subclasses Exception, so existing `except Exception` handlers still work.
    raise ValueError(f"mode name should be in this list [acrobert, pop], got {mode!r}")
if __name__ == '__main__':
    # Demo: link the acronyms in one sample sentence using the default (AcroBERT) mode.
    # Other sample inputs that have been tried:
    #   "This new genome assembly and the annotation are tagged as a RefSeq genome by NCBI
    #    and thus provide substantially enhanced genomic resources for future research
    #    involving S. scovelli."
    #   "There have been initiated several projects to modernize the network of ECB
    #    corridors, financed from ispa funds and state-guaranteed loans from international
    #    financial institutions."
    #   "A whistleblower like monologist Mike Daisey gets targeted as a scapegoat who must
    #    be discredited and diminished in the public's eye. More often than not, PR is
    #    a preemptive process. Celebrity publicists are paid lots of money to keep certain
    #    stories out of the news."
    demo_sentence = (
        "\n"
        "AI is a wide-ranging branch of computer science concerned with building smart machines capable of performing tasks that typically require human intelligence.\n"
    )
    print(acronym_linker(demo_sentence))