# ner-chemical / app.py
# lisaterumi — Upload app.py (commit 57d44a7)
import gradio as gr
import transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
# --- model loading --------------------------------------------------------
# "large" model
model_name = "pucpr/clinicalnerpt-chemical"
model_large = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer_large = AutoTokenizer.from_pretrained(model_name)
# "base" model
# NOTE(review): this loads the exact same checkpoint as the "large" model
# above — confirm whether a different (base) checkpoint was intended.
model_name = "pucpr/clinicalnerpt-chemical"
model_base = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer_base = AutoTokenizer.from_pretrained(model_name)
# --- inline CSS for highlighted entities ----------------------------------
# background color for the highlighted entity text
background_colors_entity_word = {
'ChemicalDrugs': "#fae8ff",
}
# background color for the small uppercase entity-label badge
background_colors_entity_tag = {
'ChemicalDrugs': "#d946ef",
}
# style templates; the '#xxxxxx' placeholder is replaced with the colors above
css = {
'entity_word': 'color:#000000;background: #xxxxxx; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 2.5; border-radius: 0.35em;',
'entity_tag': 'color:#fff;background: #xxxxxx; font-size: 0.8em; font-weight: bold; line-height: 2.5; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5em;'
}
# pre-rendered legend badge for the ChemicalDrugs entity
# NOTE(review): list_EN is never referenced again in this file — verify
# whether it was meant to be shown in the UI (e.g. via `description`).
list_EN = "<span style='"
list_EN += f"{css['entity_tag'].replace('#xxxxxx',background_colors_entity_tag['ChemicalDrugs'])};padding:0.5em;"
list_EN += "'>ChemicalDrugs</span>"
# --- UI metadata ----------------------------------------------------------
title = "BioBERTpt - Chemical entities"
description = "BioBERTpt - Chemical entities"
allow_screenshot = False
allow_flagging = False
# Portuguese clinical example sentences shown below the input box
examples = [
["Dispneia venoso central em subclavia D duplolumen recebendo solução salina e glicosada em BI."],
["Paciente com Sepse pulmonar em D8 tazocin (paciente não recebeu por 2 dias Atb)."],
["FOI REALIZADO CURSO DE ATB COM LEVOFLOXACINA POR 7 DIAS."],
]
def ner(input_text):
    """Highlight ChemicalDrugs entities in *input_text* as HTML.

    Runs both loaded checkpoints (large and base) over the text and
    returns a pair of HTML strings ``(output_large, output_base)``.
    Detected entity spans are wrapped in styled ``<span>`` elements using
    the module-level ``css`` templates and background-color maps.
    """
    results = []
    for tokenizer, model in zip([tokenizer_large, tokenizer_base],
                                [model_large, model_base]):
        # Tokenize, truncating to the 512-token model limit.
        inputs = tokenizer(input_text, max_length=512, truncation=True,
                           return_tensors="pt")
        # Inference only: disable autograd to save time and memory.
        with torch.no_grad():
            logits = model(**inputs).logits
        predictions = torch.argmax(logits, dim=2)
        # BUGFIX: look labels up on the *current* model of this iteration,
        # not always on model_base as the original code did.
        preds = [model.config.id2label[p] for p in predictions[0].numpy()]
        # --- group consecutive B-/I- tags into entity spans ---------------
        groups_pred = dict()    # group index -> {'indices': [...], 'en': label}
        group_indices = list()  # token indices of the entity being built
        group_label = ''        # label of the entity being built
        count = 0               # next group index
        for i, en in enumerate(preds):
            if en == 'O':
                # outside any entity: close an open group, if any
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
            if en.startswith('B'):
                # a B- tag always starts a new group
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
                group_indices.append(i)
                group_label = en.replace('B-', '')
            elif en.startswith('I'):
                if len(group_indices) > 0:
                    if en.replace('I-', '') == group_label:
                        # continuation of the current entity
                        group_indices.append(i)
                    else:
                        # label changed mid-entity: close and start a new group
                        groups_pred[count] = {'indices': group_indices, 'en': group_label}
                        group_indices = [i]
                        group_label = en.replace('I-', '')
                        count += 1
                else:
                    # I- tag with no open group: start one anyway
                    group_indices = [i]
                    group_label = en.replace('I-', '')
            # flush a group that runs to the end of the sequence
            if i == len(preds) - 1 and len(group_indices) > 0:
                groups_pred[count] = {'indices': group_indices, 'en': group_label}
                group_indices = list()
                group_label = ''
                count += 1
        # --- render HTML ---------------------------------------------------
        len_groups_pred = len(groups_pred)
        input_ids = inputs['input_ids'][0].numpy()
        if len_groups_pred > 0:
            for pred_num in range(len_groups_pred):
                en = groups_pred[pred_num]['en']
                indices = groups_pred[pred_num]['indices']
                word_style = css['entity_word'].replace('#xxxxxx', background_colors_entity_word[en])
                tag_style = css['entity_tag'].replace('#xxxxxx', background_colors_entity_tag[en])
                # the entity text wrapped in a colored span, plus its label badge
                span = (f'<span style="{word_style}">'
                        + tokenizer.decode(input_ids[indices[0]:indices[-1] + 1])
                        + f'<span style="{tag_style}">' + en + '</span></span> ')
                if pred_num == 0:
                    # leading plain text (if any) before the first entity
                    if indices[0] > 0:
                        output = tokenizer.decode(input_ids[:indices[0]]) + span
                    else:
                        output = span
                else:
                    # plain text between the previous entity and this one
                    output += tokenizer.decode(input_ids[indices_prev[-1] + 1:indices[0]]) + span
                indices_prev = indices
            # trailing plain text after the last entity
            output += tokenizer.decode(input_ids[indices_prev[-1] + 1:])
        else:
            # no entities found: echo the input unchanged
            output = input_text
        # strip special tokens and WordPiece continuation markers
        output = output.replace('[CLS]', '').replace(' [SEP]', '').replace('##', '')
        output = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + output + "</div>"
        results.append(output)
    # (output_large, output_base)
    return results[0], results[1]
# --- Gradio interface -----------------------------------------------------
# BUGFIX: the original passed `article=article`, but `article` was never
# defined anywhere in the file, raising a NameError at startup. Pass an
# empty article instead.
iface = gr.Interface(
    title=title,
    description=description,
    article="",
    allow_screenshot=allow_screenshot,
    allow_flagging=allow_flagging,
    fn=ner,
    inputs=gr.inputs.Textbox(placeholder="Digite uma frase aqui ou clique em um exemplo:", lines=5),
    outputs=[gr.outputs.HTML(label="NER1"), gr.outputs.HTML(label="NER2")],
    examples=examples,
)
iface.launch()