import numpy as np
from math import exp

import torch
from torch import nn
from transformers import BertTokenizer, BertForNextSentencePrediction

import spacy

import utils
import constant
from maddog import Extractor

# Shared resources: spaCy tokenizer, rule-based (MadDog) extractor,
# acronym knowledge base, and the fine-tuned AcroBERT checkpoint.
nlp = spacy.load("en_core_web_sm")
ruleExtractor = Extractor()
kb = utils.load_acronym_kb('acronym_kb.json')
model_path = 'acrobert.pt'


class AcronymBERT(nn.Module):
    """BERT next-sentence-prediction head used to score (long form, context) pairs."""

    def __init__(self, model_name="bert-base-uncased", device='cpu'):
        super().__init__()
        self.device = device
        self.model = BertForNextSentencePrediction.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, sentence):
        samples = self.tokenizer(sentence, padding=True, return_tensors='pt', truncation=True)["input_ids"]
        samples = samples.to(self.device)
        outputs = self.model(samples).logits
        # Probability of the "is next sentence" class (index 0).
        scores = nn.Softmax(dim=1)(outputs)[:, 0]
        return scores


model = AcronymBERT(device='cpu')
model.load_state_dict(torch.load(model_path, map_location='cpu'))


def softmax(elements):
    """Probability of the first logit; shifted by the max for numerical stability."""
    shift = max(elements)
    total = sum(exp(e - shift) for e in elements)
    return exp(elements[0] - shift) / total


def predict(topk, model, short_form, context, batch_size, acronym_kb, device):
    """Rank KB long-form candidates for `short_form` by their NSP score against `context`."""
    ori_candidate = utils.get_candidate(acronym_kb, short_form, can_num=20)
    long_terms = [can.lower() for can in ori_candidate]
    scores = cal_score(model.model, model.tokenizer, long_terms, context, batch_size, device)
    topk = min(len(scores), topk)
    indexes = np.array(scores).argsort()[::-1][:topk]
    names = [ori_candidate[i] for i in indexes]
    confidences = [round(scores[i], 3) for i in indexes]
    return names, confidences


def cal_score(model, tokenizer, long_forms, contexts, batch_size, device):
    """Score each long form paired with the context sentence, in mini-batches."""
    ps = list()
    for index in range(0, len(long_forms), batch_size):
        batch_lf = long_forms[index:index + batch_size]
        batch_ctx = [contexts] * len(batch_lf)
        encoding = tokenizer(batch_lf, batch_ctx, return_tensors="pt",
                             padding=True, truncation=True, max_length=400).to(device)
        outputs = model(**encoding)
        logits = outputs.logits.cpu().detach().numpy()
        ps.extend(softmax(lg) for lg in logits)
    return ps


def dog_extract(sentence):
    """Run the rule-based extractor alone and return its acronym/long-form pairs."""
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    return ruleExtractor.extract(tokens, constant.RULES)


def acrobert(sentence, model, device):
    """Link acronyms with rules first; fall back to AcroBERT ranking when no in-text long form is found."""
    model.to(device)
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    results = dict()
    for acronym in rulebased_pairs.keys():
        if rulebased_pairs[acronym][0] != '':
            # The rule-based extractor found the expansion in the sentence itself.
            results[acronym] = rulebased_pairs[acronym][0]
        else:
            # No in-text expansion: rank KB candidates with AcroBERT.
            pred, scores = predict(5, model, acronym, sentence, batch_size=10,
                                   acronym_kb=kb, device=device)
            results[acronym] = list(zip(pred, scores))
    return results


def popularity(sentence):
    """Baseline: rules first, then the single most popular long form in the KB."""
    tokens = [t.text for t in nlp(sentence) if len(t.text.strip()) > 0]
    rulebased_pairs = ruleExtractor.extract(tokens, constant.RULES)
    results = list()
    for acronym in rulebased_pairs.keys():
        if rulebased_pairs[acronym][0] != '':
            results.append((acronym, rulebased_pairs[acronym][0]))
        else:
            pred = utils.get_candidate(kb, acronym, can_num=1)
            results.append((acronym, pred[0]))
    return results


def acronym_linker(sentence, mode='acrobert', model=model, device='cpu'):
    if mode == 'acrobert':
        return acrobert(sentence, model, device)
    if mode == 'pop':
        return popularity(sentence)
    raise ValueError("mode must be one of ['acrobert', 'pop']")

if __name__ == '__main__':
    sentence = """
    AI is a wide-ranging branch of computer science concerned with building
    smart machines capable of performing tasks that typically require human
    intelligence.
    """
    results = acronym_linker(sentence)
    print(results)
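
    # Minimal sketch of the alternative mode (assumes the same 'acronym_kb.json'
    # loaded above): 'pop' skips AcroBERT scoring and resolves each detected
    # acronym to its most popular long form in the knowledge base.
    print(acronym_linker(sentence, mode='pop'))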