import json
import os

import gradio as gr
import numpy as np
import tokenizers
import torch
import transformers

from model import BertAD

# Acronym -> list of candidate expansions, plus the WordPiece tokenizer used at training time.
DICTIONARY = json.load(open('model/dict.json'))
TOKENIZER = tokenizers.BertWordPieceTokenizer("model/vocab.txt", lowercase=True)
MAX_LEN = 256

# Load the fine-tuned weights. The position_ids buffer is copied from the freshly
# initialised model because the checkpoint does not contain it.
MODEL = BertAD()
vec = MODEL.state_dict()['bert.embeddings.position_ids']
chkp = torch.load(os.path.join('model', 'model_0.bin'), map_location='cpu')
chkp['bert.embeddings.position_ids'] = vec
MODEL.load_state_dict(chkp)
MODEL.eval()
del chkp, vec


def sample_text(text, acronym, max_len):
    """Crop a long context to a window of roughly max_len words centred on the acronym."""
    text = text.split()
    # Fall back to the start of the text if the acronym does not appear as a standalone word.
    idx = text.index(acronym) if acronym in text else 0
    left_idx = max(0, idx - max_len // 2)
    right_idx = min(len(text), idx + max_len // 2)
    sampled_text = text[left_idx:right_idx]
    return ' '.join(sampled_text)


def process_data(text, acronym, expansion, tokenizer, max_len):
    text = str(text)
    expansion = str(expansion)
    acronym = str(acronym)

    n_tokens = len(text.split())
    if n_tokens > 120:
        text = sample_text(text, acronym, 120)

    # "Question" segment: the acronym followed by all of its candidate expansions.
    answers = acronym + ' ' + ' '.join(DICTIONARY[acronym])

    # Character-level mask over the target expansion inside `answers`.
    start = answers.find(expansion)
    end = start + len(expansion)
    char_mask = [0] * len(answers)
    for i in range(start, end):
        char_mask[i] = 1

    tok_answer = tokenizer.encode(answers)
    answer_ids = tok_answer.ids
    answer_offsets = tok_answer.offsets
    answer_ids = answer_ids[1:-1]
    answer_offsets = answer_offsets[1:-1]

    # Token indices that overlap the target character span.
    target_idx = []
    for i, (off1, off2) in enumerate(answer_offsets):
        if sum(char_mask[off1:off2]) > 0:
            target_idx.append(i)

    start = target_idx[0]
    end = target_idx[-1]

    # Final sequence: [CLS] answers [SEP] context [SEP]
    text_ids = tokenizer.encode(text).ids[1:-1]
    token_ids = [101] + answer_ids + [102] + text_ids + [102]
    offsets = [(0, 0)] + answer_offsets + [(0, 0)] * (len(text_ids) + 2)
    mask = [1] * len(token_ids)
    token_type = [0] * (len(answer_ids) + 1) + [1] * (2 + len(text_ids))

    text = answers + text
    start = start + 1  # shift by one for the leading [CLS] token
    end = end + 1

    # Pad or truncate to max_len.
    padding = max_len - len(token_ids)
    if padding >= 0:
        token_ids = token_ids + ([0] * padding)
        token_type = token_type + [1] * padding
        mask = mask + ([0] * padding)
        offsets = offsets + ([(0, 0)] * padding)
    else:
        token_ids = token_ids[0:max_len]
        token_type = token_type[0:max_len]
        mask = mask[0:max_len]
        offsets = offsets[0:max_len]

    assert len(token_ids) == max_len
    assert len(mask) == max_len
    assert len(offsets) == max_len
    assert len(token_type) == max_len

    return {
        'ids': token_ids,
        'mask': mask,
        'token_type': token_type,
        'offset': offsets,
        'start': start,
        'end': end,
        'text': text,
        'expansion': expansion,
        'acronym': acronym,
    }


def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))


def evaluate_jaccard(text, selected_text, acronym, offsets, idx_start, idx_end):
    # Recover the predicted span from the character offsets, then map it to the
    # closest candidate expansion by Jaccard similarity.
    filtered_output = ""
    for ix in range(idx_start, idx_end + 1):
        filtered_output += text[offsets[ix][0]:offsets[ix][1]]
        if (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]:
            filtered_output += " "

    candidates = DICTIONARY[acronym]
    candidate_jaccards = [jaccard(w.strip(), filtered_output.strip()) for w in candidates]
    idx = np.argmax(candidate_jaccards)

    return candidate_jaccards[idx], candidates[idx]


def disambiguate(text, acronym):
    inputs = process_data(text, acronym, acronym, TOKENIZER, MAX_LEN)

    ids = torch.tensor(inputs['ids'])
    mask = torch.tensor(inputs['mask'])
    token_type = torch.tensor(inputs['token_type'])
    offsets = inputs['offset']
    expansion = inputs['expansion']
    acronym = inputs['acronym']

    # Add the batch dimension expected by the model.
    ids = torch.unsqueeze(ids, 0)
    mask = torch.unsqueeze(mask, 0)
    token_type = torch.unsqueeze(token_type, 0)

    with torch.no_grad():
        start_logits, end_logits = MODEL(ids, mask, token_type)

    start_prob = torch.softmax(start_logits, dim=-1).numpy()
    end_prob = torch.softmax(end_logits, dim=-1).numpy()

    start_idx = np.argmax(start_prob[0, :])
    end_idx = np.argmax(end_prob[0, :])

    # The token offsets index into the concatenated "answers + context" string built in
    # process_data, so pass that string rather than the raw input text.
    _, exp = evaluate_jaccard(inputs['text'], expansion, acronym, offsets, start_idx, end_idx)

    return exp


text = gr.inputs.Textbox(
    lines=5,
    label="Context",
    default="Particularly , we explore four CNN architectures , AlexNet , GoogLeNet , VGG-16 , and ResNet to derive features for all images in our dataset , which are labeled as private or public .",
)
acronym = gr.inputs.Dropdown(choices=sorted(list(DICTIONARY.keys())), label="Acronym", default="CNN")
expansion = gr.outputs.Textbox(label="Expansion")

iface = gr.Interface(
    fn=disambiguate,
    inputs=[text, acronym],
    outputs=expansion,
    title="Scientific Acronym Disambiguation",
    description="Demo of model based on https://arxiv.org/abs/2102.08818",
)
iface.launch()