import numpy as np
from math import exp
import torch
from torch import nn
from transformers import BertTokenizer, BertForNextSentencePrediction
import utils
from maddog import Extractor
import spacy
import constant

# Shared resources loaded at import time: spaCy pipeline, MadDog rule-based extractor,
# acronym knowledge base, and the path to the fine-tuned AcroBERT checkpoint.
nlp = spacy.load("en_core_web_sm")
ruleExtractor = Extractor()
kb = utils.load_acronym_kb('acronym_kb.json')
model_path = 'acrobert.pt'


class AcronymBERT(nn.Module):
    """BERT next-sentence-prediction head reused to score how well a candidate long form matches a context."""

    def __init__(self, model_name="bert-base-uncased", device='cpu'):
        super().__init__()
        self.device = device
        self.model = BertForNextSentencePrediction.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, sentence):
        samples = self.tokenizer(sentence, padding=True, return_tensors='pt', truncation=True)["input_ids"]
        samples = samples.to(self.device)
        outputs = self.model(samples).logits
        # Index 0 of the NSP logits is the "is next sentence" class; its softmax probability serves as the score.
        scores = nn.Softmax(dim=1)(outputs)[:, 0]
        return scores
model = AcronymBERT(device='cpu')
model.load_state_dict(torch.load(model_path, map_location='cpu'))


def softmax(elements):
    # Probability that a softmax over `elements` assigns to the first entry (the NSP "is next" logit).
    total = sum(exp(e) for e in elements)
    return exp(elements[0]) / total
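# Worked example: softmax([2.0, 1.0]) = exp(2.0) / (exp(2.0) + exp(1.0)) ≈ 0.73,
# i.e. the share of probability mass that lands on the first logit.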


def predict(topk, model, short_form, context, batch_size, acronym_kb, device):
    # Fetch up to 10 candidate long forms for the acronym from the KB and rank them by NSP score.
    ori_candidate = utils.get_candidate(acronym_kb, short_form, can_num=10)
    # Lowercase the candidates to match the uncased BERT vocabulary.
    long_terms = [can.lower() for can in ori_candidate]
    scores = cal_score(model.model, model.tokenizer, long_terms, context, batch_size, device)
    topk = min(len(scores), topk)
    indexes = np.array(scores).argsort()[::-1][:topk]
    names = [ori_candidate[i] for i in indexes]
    return names
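# Illustrative call (the candidates actually returned depend on acronym_kb.json and the checkpoint):
# predict(3, model, 'NLP', 'NLP systems turn raw text into structured data',
#         batch_size=10, acronym_kb=kb, device='cpu')
# would return up to three candidate long forms, highest-scoring first.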


def cal_score(model, tokenizer, long_forms, contexts, batch_size, device):
    ps = list()
    for index in range(0, len(long_forms), batch_size):
        batch_lf = long_forms[index:index + batch_size]
        # Pair every candidate long form in the batch with the same context sentence.
        batch_ctx = [contexts] * len(batch_lf)
        encoding = tokenizer(batch_lf, batch_ctx, return_tensors="pt", padding=True, truncation=True, max_length=400).to(device)
        outputs = model(**encoding)
        logits = outputs.logits.cpu().detach().numpy()
        p = [softmax(lg) for lg in logits]
        ps.extend(p)
    return ps


def dog_extract(sentence):
    # Rule-based (MadDog) extraction of (acronym, long form) pairs that are spelled out in the sentence itself.
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    return rulebased_pairs
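# Illustrative behaviour (subject to MadDog's rule set): in a sentence such as
# "The World Health Organization (WHO) issued new guidance", the rule-based extractor
# should recover the long form for 'WHO' directly from the parenthetical definition,
# so no model-based ranking is needed.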


def acrobert(sentence, model, device):
    model.to(device)
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    results = list()
    for acronym in rulebased_pairs.keys():
        if rulebased_pairs[acronym][0] != '':
            # The rule-based extractor already found the long form in the sentence itself.
            results.append((acronym, rulebased_pairs[acronym][0]))
        else:
            # Otherwise fall back to ranking the KB candidates with the AcroBERT model.
            pred = predict(1, model, acronym, sentence, batch_size=10, acronym_kb=kb, device=device)
            results.append((acronym, pred[0]))
    return results


def popularity(sentence):
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    results = list()
    for acronym in rulebased_pairs.keys():
        if rulebased_pairs[acronym][0] != '':
            results.append((acronym, rulebased_pairs[acronym][0]))
        else:
            pred = utils.get_candidate(kb, acronym, can_num=1)
            results.append((acronym, pred[0]))
    return results


def acronym_linker(sentence, mode='acrobert', model=model, device='cpu'):
    if mode == 'acrobert':
        return acrobert(sentence, model, device)
    if mode == 'pop':
        return popularity(sentence)
    raise ValueError("mode should be one of ['acrobert', 'pop']")
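# Illustrative usage (output depends on the KB and checkpoint shipped with the project):
# acronym_linker("CNNs are widely used in CV", mode='pop')
# returns a list of (acronym, long_form) tuples, one per acronym found in the sentence.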


if __name__ == '__main__':
    # Alternative test sentences from the original script:
    # sentence = "This new genome assembly and the annotation are tagged as a RefSeq genome by NCBI and thus provide substantially enhanced genomic resources for future research involving S. scovelli."
    # sentence = """There have been initiated several projects to modernize the network of ECB
    # corridors, financed from ispa funds and state-guaranteed loans from international
    # financial institutions."""
    # sentence = """A whistleblower like monologist Mike Daisey gets targeted as a scapegoat who must
    # be discredited and diminished in the public's eye. More often than not, PR is
    # a preemptive process. Celebrity publicists are paid lots of money to keep certain
    # stories out of the news."""
    sentence = "AI is the ability of a digital computer or computer-controlled robot to perform tasks commonly associated with intelligent beings, including NLP that processes text or document"
    results = acronym_linker(sentence)
    print(results)
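    # With the default KB and checkpoint this is expected to print pairs such as
    # [('AI', 'artificial intelligence'), ('NLP', 'natural language processing')],
    # though the exact long forms depend on acronym_kb.json.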