|
import copy
import functools
from itertools import chain

import gradio as gr
import ipdb  # NOTE(review): not referenced anywhere in this file; consider removing.
import torch
from transformers import AutoConfig, AutoTokenizer

from bert_graph import BertForMultipleChoice
|
|
|
|
|
|
|
|
|
def preprocess_function_exp(examples, tokenizer, max_length=512):
    """Tokenize tab-separated sentence pairs and regroup the result per example.

    Args:
        examples: List of examples; each example is a list of strings of the
            form ``first<TAB>second``. Only the first two tab-separated fields
            of a line are used, and each is stripped of surrounding whitespace.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer). It is
            invoked once on all pairs flattened across examples as
            ``tokenizer(firsts, seconds, max_length=..., padding=False,
            truncation=True)``. It must return a mapping of feature name ->
            flat list with one entry per pair.
        max_length: Truncation length passed to the tokenizer (default 512).

    Returns:
        dict: Feature name -> list with one entry per example. Each entry is
        the list of that example's per-pair values, in input order.

    Raises:
        ValueError: If a line contains no tab separator.
    """
    pair_counts = [len(line_list) for line_list in examples]

    first_sentences = []
    second_sentences = []
    for line_list in examples:
        for line in line_list:
            # Split the raw line *before* stripping. Stripping first would
            # swallow the separator tab whenever either side is empty
            # (e.g. "a\t" -> "a"), losing or shifting a field.
            parts = line.split('\t')
            if len(parts) < 2:
                raise ValueError(f'expected a tab-separated sentence pair, got {line!r}')
            first_sentences.append(parts[0].strip())
            second_sentences.append(parts[1].strip())

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        max_length=max_length,
        padding=False,
        truncation=True,
    )

    # The tokenizer saw one flat batch of pairs; slice each feature back into
    # per-example groups using the recorded pair counts.
    tokenized_inputs = {}
    for key, values in tokenized_examples.items():
        grouped = []
        start = 0
        for count in pair_counts:
            grouped.append(values[start:start + count])
            start += count
        tokenized_inputs[key] = grouped
    return tokenized_inputs
|
|
|
def DCForMultipleChoice(features, tokenizer):
    """Pad grouped pair features into a ``(num_examples, num_choices, seq_len)`` batch.

    Args:
        features: Mapping of feature name (e.g. ``input_ids``) to a list of
            examples. Each example is a list of per-pair token sequences, as
            produced by ``preprocess_function_exp``. Every example must hold
            the same number of pairs.
        tokenizer: Object exposing a Hugging Face style ``pad`` method.

    Returns:
        dict: Feature name -> tensor of shape
        ``(num_examples, num_choices, padded_len)``.

    Raises:
        ValueError: If ``features`` is empty, holds no examples, or its
            examples have differing numbers of pairs.
    """
    if not features:
        raise ValueError('features must not be empty')
    first_feature = next(iter(features.values()))
    num_examples = len(first_feature)
    choice_counts = {len(example) for example in first_feature}
    if len(choice_counts) != 1:
        raise ValueError('expected a non-empty batch whose examples all contain the same number of pairs')
    (num_choices,) = choice_counts

    # One flat dict per (example, pair), in example-major order, so the
    # padded batch can be folded back with a single view() below.
    flattened_features = [
        {key: values[example_idx][choice_idx] for key, values in features.items()}
        for example_idx in range(num_examples)
        for choice_idx in range(num_choices)
    ]

    # With padding=True (longest), max_length does not drive the pad length.
    batch = tokenizer.pad(
        flattened_features,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )
    return {key: tensor.view(num_examples, num_choices, -1) for key, tensor in batch.items()}
|
|
|
def post_process_diag(predictions):
    """Zero the self-pair (diagonal) entries of a flattened square prediction matrix.

    ``predictions`` is viewed as a row-major ``n x n`` matrix flattened to
    length ``n * n``. Entry ``(i, i)`` pairs an item with itself and is set to 0.

    Args:
        predictions: 1-D tensor whose length is a perfect square.

    Returns:
        torch.Tensor: A new 1-D tensor of the same length and dtype with the
        diagonal entries zeroed. The input tensor is not modified.

    Raises:
        ValueError: If ``len(predictions)`` is not a perfect square.
    """
    size = len(predictions)
    num_sentences = int(round(size ** 0.5))
    if num_sentences * num_sentences != size:
        raise ValueError(f'expected a perfect-square number of predictions, got {size}')

    # clone(): reshape() may return a view of the caller's tensor, and the
    # in-place fill below must not leak back into it.
    predictions_mtx = predictions.reshape(num_sentences, num_sentences).clone()
    predictions_mtx.fill_diagonal_(0)
    return predictions_mtx.reshape(-1)
|
|
|
def max_vote(logits1, logits2, pred1, pred2):
    """Per row, keep the prediction of whichever branch is more confident.

    A branch's confidence on a row is the margin between its two largest
    softmax probabilities. The branch with the larger margin wins, and ties go
    to the first branch. Before voting, both prediction vectors pass through
    ``post_process_diag``, which zeros the self-pair (diagonal) entries.

    Args:
        logits1: 2-D tensor ``(num_rows, num_classes)`` for branch 1;
            ``num_classes`` must be >= 2 (``topk(k=2)``).
        logits2: Same as ``logits1``, for branch 2.
        pred1: 1-D integer tensor of branch-1 predictions, one per row. The
            caller in this file passes ``logits1.argmax(dim=-1)``. Its length
            must be a perfect square (required by ``post_process_diag``).
        pred2: Same as ``pred1``, for branch 2.

    Returns:
        tuple[list[int], list[float]]: The winning prediction per row, and
        the winning branch's probability margin per row.
    """
    pred1 = post_process_diag(pred1)
    pred2 = post_process_diag(pred2)

    def _top2_margin(logits):
        # Explicit dim=-1: the implicit-dim softmax form is deprecated and warns.
        top2 = torch.nn.functional.softmax(logits.detach(), dim=-1).topk(k=2, dim=-1).values
        return top2[..., 0] - top2[..., 1]

    margin1 = _top2_margin(logits1)
    margin2 = _top2_margin(logits2)
    use_first = margin1 >= margin2  # ">=" keeps ties on branch 1

    pred_res = [int(p) for p in torch.where(use_first, pred1, pred2).tolist()]
    confidence_res = torch.where(use_first, margin1, margin2).tolist()
    return pred_res, confidence_res
|
|
|
# Backbone, fine-tuned checkpoint and label names used by the demo.
_BACKBONE_NAME = 'michiyasunaga/BioLinkBERT-base'
_CHECKPOINT_PATH = 'D:/Code/Antidote/ari_model/best.pth'
_LABEL_SPACE = {0: 'not relates', 1: 'supports', 2: 'attack'}


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Build the model, load the fine-tuned weights and the tokenizer once, then reuse them."""
    config = AutoConfig.from_pretrained(_BACKBONE_NAME)
    config.win_size = 13
    config.model_mode = 'bert_mtl_1d'
    config.dataset_domain = 'absRCT'
    config.voter_branch = 'dual'
    config.destroy = False

    model = BertForMultipleChoice.from_pretrained(_BACKBONE_NAME, config=config)
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at a trusted checkpoint.
    state_dict = torch.load(_CHECKPOINT_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # inference only: make sure dropout-style layers are disabled
    tokenizer = AutoTokenizer.from_pretrained(_BACKBONE_NAME)
    return model, tokenizer


def model_infer(input_a, input_b):
    """Predict the relation between two arguments in both directions.

    Builds the four ordered pairs (A-A, A-B, B-A, B-B) and runs them through
    the fine-tuned model as one multiple-choice example. The two logit
    branches are resolved with ``max_vote``, and labels are reported for A->B
    and B->A.

    Args:
        input_a: Head argument text.
        input_b: Tail argument text.

    Returns:
        str: Two lines, ``Head Argument <label> Tail Argument`` (A->B) and
        ``Tail Argument <label> Head Argument`` (B->A). If either input is
        empty, a prompt asking for both inputs is returned instead.
    """
    if not (input_a and input_a.strip()) or not (input_b and input_b.strip()):
        return 'Please enter both a Head Argument and a Tail Argument.'

    # Pairs are encoded as "first<TAB>second" for preprocess_function_exp, so a
    # tab inside the user text would split a pair in the wrong place.
    input_a = input_a.replace('\t', ' ')
    input_b = input_b.replace('\t', ' ')

    model, tokenizer = _load_model_and_tokenizer()

    # Order matters: index 1 is A->B and index 2 is B->A.
    examples = [[
        input_a + '\t' + input_a,
        input_a + '\t' + input_b,
        input_b + '\t' + input_a,
        input_b + '\t' + input_b,
    ]]
    tokenized_inputs = preprocess_function_exp(examples, tokenizer)
    tokenized_inputs = DCForMultipleChoice(tokenized_inputs, tokenizer)

    with torch.no_grad():
        outputs = model(**tokenized_inputs)

    # NOTE(review): presumably one logits tensor per voter branch
    # (voter_branch='dual'), each with one row per pair; confirm against bert_graph.
    logits_1, logits_2 = outputs.logits[0], outputs.logits[1]
    predictions, _scores = max_vote(logits_1, logits_2, logits_1.argmax(dim=-1), logits_2.argmax(dim=-1))

    label_a_b = _LABEL_SPACE[predictions[1]]
    label_b_a = _LABEL_SPACE[predictions[2]]

    # Report both directions; each label needs its own placeholder.
    return 'Head Argument {} Tail Argument\nTail Argument {} Head Argument'.format(label_a_b, label_b_a)
|
|
|
|
|
|
|
# Sample sentences offered as clickable examples in both argument boxes.
EXAMPLE_SENTENCES = (
    "Compared with baseline measurements, both latanoprost and timolol caused a significant (P < 0.001) reduction of IOP at each hour of diurnal curve throughout the duration of therapy.",
    "Reduction of IOP was 6.0 +/- 4.5 and 5.9 +/- 4.6 with latanoprost and 4.8 +/- 3.0 and 4.6 +/- 3.1 with timolol after 6 and 12 months, respectively.",
    "Comparison of mean diurnal measurements with latanoprost and timolol showed a statistical significant (P < 0.001) difference at 3, 6, and 12 months.",
    "Mean C was found to be significantly enhanced (+30%) only in the latanoprost-treated group compared with the baseline (P = 0.017).",
    "Mean conjunctival hyperemia was graded at 0.3 in latanoprost-treated eyes and 0.2 in timolol-treated eyes.",
    "A remarkable change in iris color was observed in both eyes of 1 of the 18 patients treated with latanoprost and none of the 18 patients who received timolol.",
    "In the timolol group, heart rate was significantly reduced from 72 +/- 9 at baseline to 67 +/- 10 beats per minute at 12 months.",
    "in patients with pigmentary glaucoma, 0.005% latanoprost taken once daily was well tolerated and more effective in reducing IOP than 0.5% timolol taken twice daily.",
    "further studies may need to confirm these data on a larger sample and to evaluate the side effect of increased iris pigmentation on long-term follow-up,",
)

with gr.Blocks() as demo:
    arg_1 = gr.Textbox(label="Head Argument")
    arg_2 = gr.Textbox(label="Tail Argument")

    # Each Examples component gets its own list copy of the shared samples.
    gr.Examples(list(EXAMPLE_SENTENCES), arg_1)
    gr.Examples(list(EXAMPLE_SENTENCES), arg_2)

    output = gr.Textbox(label="Output Box")

    run_btn = gr.Button("Run")
    run_btn.click(fn=model_infer, inputs=[arg_1, arg_2], outputs=output)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()