File size: 5,025 Bytes
fa0a93c f3f272c fa0a93c f3f272c fa0a93c f3f272c fa0a93c f3f272c fa0a93c f3f272c fa0a93c f3f272c fa0a93c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import numpy as np
from math import exp
import torch
from torch import nn
from transformers import BertTokenizer, BertForNextSentencePrediction
import utils
from maddog import Extractor
import spacy
import constant
# Module-level resources, created once at import time and shared by the
# functions below.
# spaCy English pipeline; in this file it is only used to split a sentence
# into token texts.
nlp = spacy.load("en_core_web_sm")
# Rule-based acronym/definition extractor from the maddog package.
ruleExtractor = Extractor()
# Acronym knowledge base queried via utils.get_candidate; presumably a mapping
# from acronym to candidate long forms — TODO confirm against utils.
kb = utils.load_acronym_kb('acronym_kb.json')
# Fine-tuned checkpoint, loaded below as a state_dict for AcronymBERT.
model_path='acrobert.pt'
class AcronymBERT(nn.Module):
    """BERT next-sentence-prediction model plus its tokenizer.

    The wrapped ``BertForNextSentencePrediction`` is a registered submodule, so
    ``load_state_dict`` on this wrapper restores its fine-tuned weights; the
    tokenizer is stored alongside it for convenience.
    """

    def __init__(self, model_name="bert-base-uncased", device='cpu'):
        super().__init__()
        # Kept for backward compatibility; forward() uses the device the
        # parameters actually live on, because the constructor does not move
        # the model and Module.to() does not update this attribute.
        self.device = device
        self.model = BertForNextSentencePrediction.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, sentence):
        """Score inputs with the NSP head.

        Args:
            sentence: text (or list of texts / text pairs) accepted by the
                tokenizer.

        Returns:
            1-D tensor with the softmax probability of NSP class 0
            ("sequence B follows sequence A") for each input.
        """
        encoding = self.tokenizer(sentence, padding=True, return_tensors='pt', truncation=True)
        # Move inputs to wherever the weights are, so calling .to(device) on
        # this module is enough to switch devices.
        target_device = next(self.model.parameters()).device
        encoding = encoding.to(target_device)
        # Pass the full encoding, not just input_ids. With padding=True the
        # attention_mask is needed so padded positions are ignored, and
        # token_type_ids carry the segment ids for sentence pairs.
        logits = self.model(**encoding).logits
        return nn.Softmax(dim=1)(logits)[:, 0]
# Build the wrapper (base weights come from BertForNextSentencePrediction
# .from_pretrained) and replace them with the fine-tuned checkpoint, mapping
# all tensors onto the CPU. This runs at import time.
model = AcronymBERT(device='cpu')
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (torch>=1.13 offers weights_only=True to restrict loading to tensors).
model.load_state_dict(torch.load(model_path, map_location='cpu'))
def softmax(elements):
    """Return the softmax probability of the FIRST element of ``elements``.

    Only component 0 is returned. In this file, ``cal_score`` applies it to a
    row of NSP logits, where class 0 means "B follows A". The maximum is
    subtracted before exponentiating, so large logits do not overflow
    ``math.exp``; the result is mathematically unchanged.

    Args:
        elements: non-empty sequence of real numbers (e.g. one row of logits).

    Returns:
        float in [0, 1].
    """
    shift = max(elements)
    weights = [exp(e - shift) for e in elements]
    return weights[0] / sum(weights)
def predict(topk, model, short_form, context, batch_size, acronym_kb, device):
    """Rank knowledge-base expansions of ``short_form`` against ``context``.

    Up to 20 candidates are fetched from ``acronym_kb``, lower-cased for
    scoring, and scored with the NSP model wrapped by ``model``.

    Returns:
        ``(names, confidences)``: the best ``topk`` candidates, highest score
        first, with their original casing, and their scores rounded to 3
        decimal places.
    """
    candidates = utils.get_candidate(acronym_kb, short_form, can_num=20)
    lowered = [str.lower(cand) for cand in candidates]
    probs = cal_score(model.model, model.tokenizer, lowered, context, batch_size, device)

    keep = min(len(probs), topk)
    # argsort is ascending; reverse it to get best-first order.
    ranking = np.array(probs).argsort()[::-1][:keep]

    best_names = [candidates[pos] for pos in ranking]
    best_scores = [round(probs[pos], 3) for pos in ranking]
    return best_names, best_scores
def cal_score(model, tokenizer, long_forms, contexts, batch_size, device):
    """Score candidate long forms against one context with a BERT NSP model.

    Each long form is encoded as sentence A, paired with the context as
    sentence B. The probability of NSP class 0 ("B follows A") is kept as its
    score.

    Args:
        model: NSP model called with the tokenizer output; its result must
            expose ``.logits``. Should already live on ``device``.
        tokenizer: tokenizer matching ``model``.
        long_forms: list of candidate expansions.
        contexts: despite the plural name, a single context value that is
            repeated for every long form in a batch (a sentence string in this
            file).
        batch_size: number of (long form, context) pairs per forward pass.
        device: torch device the encodings are moved to.

    Returns:
        list of float scores, one per entry of ``long_forms``, in the same
        order.
    """
    ps = []
    # Pure inference: no_grad avoids recording the autograd graph and keeping
    # activations alive for a backward pass that never happens.
    with torch.no_grad():
        for start in range(0, len(long_forms), batch_size):
            batch_lf = long_forms[start:start + batch_size]
            batch_ctx = [contexts] * len(batch_lf)
            encoding = tokenizer(batch_lf, batch_ctx, return_tensors="pt", padding=True, truncation=True, max_length=400).to(device)
            logits = model(**encoding).logits.cpu().numpy()
            ps.extend(softmax(row) for row in logits)
    return ps
def dog_extract(sentence):
    """Run the rule-based (maddog) acronym extractor over ``sentence``.

    The sentence is tokenized with spaCy, whitespace-only tokens are dropped,
    and the remaining token texts are passed to the extractor together with
    ``constant.RULES``.

    Returns:
        The extractor's result unchanged. Elsewhere in this file it is used as
        a mapping from acronym to a value whose item ``[0]`` is the in-text
        long form, or ``''`` when none was found.
    """
    words = []
    for tok in nlp(sentence):
        if tok.text.strip():
            words.append(tok.text)
    return ruleExtractor.extract(words, constant.RULES)
def acrobert(sentence, model, device):
    """Link every acronym in ``sentence`` to ranked long-form candidates.

    Acronyms whose definition appears in the sentence itself (found by the
    rule-based extractor) are resolved directly. The others are disambiguated
    by ranking knowledge-base candidates with ``model``.

    Args:
        sentence: input text.
        model: ``AcronymBERT`` instance; moved to ``device`` in place.
        device: torch device used for inference.

    Returns:
        dict mapping each acronym to a list of ``(long_form, confidence)``
        tuples, best first. An in-text definition yields a single entry with
        confidence ``1.0``. A model prediction yields up to 5 entries, or an
        empty list if the KB has no candidates.
    """
    model.to(device)
    results = {}
    for acronym, extracted in dog_extract(sentence).items():
        long_form = extracted[0]
        if long_form != '':
            # The definition is stated explicitly in the text. Store it in the
            # same list-of-(name, score) shape as model predictions, so every
            # value in the returned dict has one type.
            results[acronym] = [(long_form, 1.0)]
        else:
            names, scores = predict(5, model, acronym, sentence, batch_size=10, acronym_kb=kb, device=device)
            results[acronym] = list(zip(names, scores))
    return results
def popularity(sentence):
    """Link acronyms in ``sentence`` using the top knowledge-base candidate.

    In-text definitions found by the rule-based extractor take precedence.
    Otherwise the first candidate returned by ``utils.get_candidate`` is used.

    Returns:
        list of ``(acronym, long_form)`` tuples in the order the extractor
        reports the acronyms. ``long_form`` is ``''`` (the extractor's own
        "not found" marker) when the KB offers no candidate.
    """
    results = []
    for acronym, extracted in dog_extract(sentence).items():
        long_form = extracted[0]
        if long_form == '':
            candidates = utils.get_candidate(kb, acronym, can_num=1)
            # Presumably get_candidate returns an empty sequence for acronyms
            # missing from the KB (TODO confirm in utils). Guard so that case
            # no longer raises IndexError.
            long_form = candidates[0] if candidates else ''
        results.append((acronym, long_form))
    return results
def acronym_linker(sentence, mode='acrobert', model=model, device='cpu'):
    """Link acronyms in ``sentence`` with the selected strategy.

    Args:
        sentence: input text.
        mode: ``'acrobert'`` runs the BERT-based ranking (``acrobert``);
            ``'pop'`` picks the top knowledge-base candidate (``popularity``).
        model: ``AcronymBERT`` instance used in ``'acrobert'`` mode. Defaults
            to the module-level model loaded at import time.
        device: torch device used in ``'acrobert'`` mode.

    Returns:
        Whatever the selected strategy returns.

    Raises:
        ValueError: if ``mode`` is not a supported name.
    """
    if mode == 'acrobert':
        return acrobert(sentence, model, device)
    if mode == 'pop':
        return popularity(sentence)
    # ValueError subclasses Exception, so existing `except Exception` callers
    # still catch it.
    raise ValueError(f"mode name should be in this list [acrobert, pop], got {mode!r}")
if __name__ == '__main__':
    # Demo: link the acronyms in a sample sentence with the default mode and
    # print the result.
    # Alternative sample inputs:
    #   "This new genome assembly and the annotation are tagged as a RefSeq
    #    genome by NCBI and thus provide substantially enhanced genomic
    #    resources for future research involving S. scovelli."
    #   "There have been initiated several projects to modernize the network
    #    of ECB corridors, financed from ispa funds and state-guaranteed loans
    #    from international financial institutions."
    #   "A whistleblower like monologist Mike Daisey gets targeted as a
    #    scapegoat who must be discredited and diminished in the public ’s
    #    eye. More often than not, PR is a preemptive process. Celebrity
    #    publicists are paid lots of money to keep certain stories out of the
    #    news."
    sentence = """
AI is a wide-ranging branch of computer science concerned with building smart machines capable of performing tasks that typically require human intelligence.
"""
    print(acronym_linker(sentence))