import numpy as np
from math import exp

import torch
from torch import nn
from transformers import BertTokenizer, BertForNextSentencePrediction

import utils
from maddog import Extractor
import spacy
import constant

nlp = spacy.load("en_core_web_sm")
ruleExtractor = Extractor()
kb = utils.load_acronym_kb('acronym_kb.json')
model_path = 'acrobert.pt'
class AcronymBERT(nn.Module):
    def __init__(self, model_name="bert-base-uncased", device='cpu'):
        super().__init__()
        self.device = device
        # The NSP head scores how well a candidate long form "follows" the context.
        self.model = BertForNextSentencePrediction.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, sentence):
        samples = self.tokenizer(sentence, padding=True, return_tensors='pt', truncation=True)["input_ids"]
        samples = samples.to(self.device)
        outputs = self.model(samples).logits
        # Probability of the "isNext" class (index 0) for each input pair.
        scores = nn.Softmax(dim=1)(outputs)[:, 0]
        return scores
model = AcronymBERT(device='cpu')
# map_location='cpu' lets the checkpoint load on machines without a GPU.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
def softmax(elements):
    # Despite the name, this returns only the softmax probability of the
    # first logit, i.e. the NSP "isNext" class for a (long form, context) pair.
    total = sum(exp(e) for e in elements)
    return exp(elements[0]) / total
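# Worked example: softmax([2.0, 0.5]) == exp(2.0) / (exp(2.0) + exp(0.5)) ≈ 0.818,
# so a larger first logit maps to a score near 1.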
def predict(topk, model, short_form, context, batch_size, acronym_kb, device):
    # Retrieve up to 20 candidate long forms for the acronym from the KB,
    # then rank them by their NSP score against the context.
    ori_candidate = utils.get_candidate(acronym_kb, short_form, can_num=20)
    long_terms = [can.lower() for can in ori_candidate]
    scores = cal_score(model.model, model.tokenizer, long_terms, context, batch_size, device)
    topk = min(len(scores), topk)
    indexes = np.array(scores).argsort()[::-1][:topk]
    names = [ori_candidate[i] for i in indexes]
    confidences = [round(scores[i], 3) for i in indexes]
    return names, confidences
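# Hypothetical sanity check (assumes acronym_kb.json has an 'NLP' entry; actual
# candidates and scores depend on the KB contents and the loaded checkpoint):
# names, confs = predict(3, model, 'NLP', 'NLP lets machines parse text.',
#                        batch_size=10, acronym_kb=kb, device='cpu')
# names holds the top-3 candidate long forms, confs their rounded NSP scores.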
def cal_score(model, tokenizer, long_forms, contexts, batch_size, device):
    # Score each (long form, context) pair with the NSP head, in batches;
    # the same context sentence is paired with every candidate long form.
    ps = list()
    for index in range(0, len(long_forms), batch_size):
        batch_lf = long_forms[index:index + batch_size]
        batch_ctx = [contexts] * len(batch_lf)
        encoding = tokenizer(batch_lf, batch_ctx, return_tensors="pt", padding=True, truncation=True, max_length=400).to(device)
        outputs = model(**encoding)
        logits = outputs.logits.cpu().detach().numpy()
        p = [softmax(lg) for lg in logits]
        ps.extend(p)
    return ps
def dog_extract(sentence):
    # Rule-based acronym/long-form extraction (MadDog).
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    return rulebased_pairs
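# Note: judging from how the result is consumed below, the extractor returns a
# dict keyed by detected acronym whose value's first element is the long form
# found in the sentence itself, or '' when the sentence never spells it out.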
def acrobert(sentence, model, device):
    model.to(device)
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    results = dict()
    for acronym in rulebased_pairs.keys():
        if rulebased_pairs[acronym][0] != '':
            # The sentence spells the acronym out; trust the rule-based match.
            results[acronym] = rulebased_pairs[acronym][0]
        else:
            # Otherwise rank KB candidates with AcroBERT and keep the top 5.
            pred, scores = predict(5, model, acronym, sentence, batch_size=10, acronym_kb=kb, device=device)
            results[acronym] = list(zip(pred, scores))
    return results
def popularity(sentence):
    # Baseline: when the sentence gives no expansion, fall back to the
    # single most popular long form in the KB.
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    results = list()
    for acronym in rulebased_pairs.keys():
        if rulebased_pairs[acronym][0] != '':
            results.append((acronym, rulebased_pairs[acronym][0]))
        else:
            pred = utils.get_candidate(kb, acronym, can_num=1)
            results.append((acronym, pred[0]))
    return results
def acronym_linker(sentence, mode='acrobert', model=model, device='cpu'):
    if mode == 'acrobert':
        return acrobert(sentence, model, device)
    if mode == 'pop':
        return popularity(sentence)
    raise ValueError("mode must be one of ['acrobert', 'pop']")
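# Hypothetical calls (outputs depend on the KB and the loaded checkpoint):
# acronym_linker('The NCBI hosts genome data.')              # dict: acronym -> expansion or ranked (expansion, score) list
# acronym_linker('The NCBI hosts genome data.', mode='pop')  # list of (acronym, most popular expansion)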
if __name__ == '__main__':
sentence = """ | |
AI is a wide-ranging branch of computer science concerned with building smart machines capable of performing tasks that typically require human intelligence. | |
""" | |
results = acronym_linker(sentence) | |
print(results) |